1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Logging.hpp>
6*89c4ff92SAndroid Build Coastguard Worker
7*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_ONNX_PARSER)
8*89c4ff92SAndroid Build Coastguard Worker #include <armnnOnnxParser/IOnnxParser.hpp>
9*89c4ff92SAndroid Build Coastguard Worker #endif
10*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_SERIALIZER)
11*89c4ff92SAndroid Build Coastguard Worker #include <armnnSerializer/ISerializer.hpp>
12*89c4ff92SAndroid Build Coastguard Worker #endif
13*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_TF_LITE_PARSER)
14*89c4ff92SAndroid Build Coastguard Worker #include <armnnTfLiteParser/ITfLiteParser.hpp>
15*89c4ff92SAndroid Build Coastguard Worker #endif
16*89c4ff92SAndroid Build Coastguard Worker
17*89c4ff92SAndroid Build Coastguard Worker #include <HeapProfiling.hpp>
18*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/NumericCast.hpp>
19*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/StringUtils.hpp>
20*89c4ff92SAndroid Build Coastguard Worker
/*
 * Historically we use the ',' character to separate dimensions in a tensor shape. However, cxxopts will read this
 * as an array of values, which is fine until we have multiple tensors specified. This lumps the values of all shapes
 * together in a single array and we cannot break it up again. We'll change the vector delimiter to a '.'. We do this
 * as close as possible to the usage of cxxopts to avoid polluting other possible uses.
 */
27*89c4ff92SAndroid Build Coastguard Worker #define CXXOPTS_VECTOR_DELIMITER '.'
28*89c4ff92SAndroid Build Coastguard Worker #include <cxxopts/cxxopts.hpp>
29*89c4ff92SAndroid Build Coastguard Worker
30*89c4ff92SAndroid Build Coastguard Worker #include <fmt/format.h>
31*89c4ff92SAndroid Build Coastguard Worker
32*89c4ff92SAndroid Build Coastguard Worker #include <cstdlib>
33*89c4ff92SAndroid Build Coastguard Worker #include <fstream>
34*89c4ff92SAndroid Build Coastguard Worker #include <iostream>
35*89c4ff92SAndroid Build Coastguard Worker
36*89c4ff92SAndroid Build Coastguard Worker namespace
37*89c4ff92SAndroid Build Coastguard Worker {
38*89c4ff92SAndroid Build Coastguard Worker
ParseTensorShape(std::istream & stream)39*89c4ff92SAndroid Build Coastguard Worker armnn::TensorShape ParseTensorShape(std::istream& stream)
40*89c4ff92SAndroid Build Coastguard Worker {
41*89c4ff92SAndroid Build Coastguard Worker std::vector<unsigned int> result;
42*89c4ff92SAndroid Build Coastguard Worker std::string line;
43*89c4ff92SAndroid Build Coastguard Worker
44*89c4ff92SAndroid Build Coastguard Worker while (std::getline(stream, line))
45*89c4ff92SAndroid Build Coastguard Worker {
46*89c4ff92SAndroid Build Coastguard Worker std::vector<std::string> tokens = armnn::stringUtils::StringTokenizer(line, ",");
47*89c4ff92SAndroid Build Coastguard Worker for (const std::string& token : tokens)
48*89c4ff92SAndroid Build Coastguard Worker {
49*89c4ff92SAndroid Build Coastguard Worker if (!token.empty())
50*89c4ff92SAndroid Build Coastguard Worker {
51*89c4ff92SAndroid Build Coastguard Worker try
52*89c4ff92SAndroid Build Coastguard Worker {
53*89c4ff92SAndroid Build Coastguard Worker result.push_back(armnn::numeric_cast<unsigned int>(std::stoi((token))));
54*89c4ff92SAndroid Build Coastguard Worker }
55*89c4ff92SAndroid Build Coastguard Worker catch (const std::exception&)
56*89c4ff92SAndroid Build Coastguard Worker {
57*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(error) << "'" << token << "' is not a valid number. It has been ignored.";
58*89c4ff92SAndroid Build Coastguard Worker }
59*89c4ff92SAndroid Build Coastguard Worker }
60*89c4ff92SAndroid Build Coastguard Worker }
61*89c4ff92SAndroid Build Coastguard Worker }
62*89c4ff92SAndroid Build Coastguard Worker
63*89c4ff92SAndroid Build Coastguard Worker return armnn::TensorShape(armnn::numeric_cast<unsigned int>(result.size()), result.data());
64*89c4ff92SAndroid Build Coastguard Worker }
65*89c4ff92SAndroid Build Coastguard Worker
ParseCommandLineArgs(int argc,char * argv[],std::string & modelFormat,std::string & modelPath,std::vector<std::string> & inputNames,std::vector<std::string> & inputTensorShapeStrs,std::vector<std::string> & outputNames,std::string & outputPath,bool & isModelBinary)66*89c4ff92SAndroid Build Coastguard Worker int ParseCommandLineArgs(int argc, char* argv[],
67*89c4ff92SAndroid Build Coastguard Worker std::string& modelFormat,
68*89c4ff92SAndroid Build Coastguard Worker std::string& modelPath,
69*89c4ff92SAndroid Build Coastguard Worker std::vector<std::string>& inputNames,
70*89c4ff92SAndroid Build Coastguard Worker std::vector<std::string>& inputTensorShapeStrs,
71*89c4ff92SAndroid Build Coastguard Worker std::vector<std::string>& outputNames,
72*89c4ff92SAndroid Build Coastguard Worker std::string& outputPath, bool& isModelBinary)
73*89c4ff92SAndroid Build Coastguard Worker {
74*89c4ff92SAndroid Build Coastguard Worker cxxopts::Options options("ArmNNConverter", "Convert a neural network model from provided file to ArmNN format.");
75*89c4ff92SAndroid Build Coastguard Worker try
76*89c4ff92SAndroid Build Coastguard Worker {
77*89c4ff92SAndroid Build Coastguard Worker std::string modelFormatDescription("Format of the model file");
78*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_ONNX_PARSER)
79*89c4ff92SAndroid Build Coastguard Worker modelFormatDescription += ", onnx-binary, onnx-text";
80*89c4ff92SAndroid Build Coastguard Worker #endif
81*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_TF_PARSER)
82*89c4ff92SAndroid Build Coastguard Worker modelFormatDescription += ", tensorflow-binary, tensorflow-text";
83*89c4ff92SAndroid Build Coastguard Worker #endif
84*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_TF_LITE_PARSER)
85*89c4ff92SAndroid Build Coastguard Worker modelFormatDescription += ", tflite-binary";
86*89c4ff92SAndroid Build Coastguard Worker #endif
87*89c4ff92SAndroid Build Coastguard Worker modelFormatDescription += ".";
88*89c4ff92SAndroid Build Coastguard Worker options.add_options()
89*89c4ff92SAndroid Build Coastguard Worker ("help", "Display usage information")
90*89c4ff92SAndroid Build Coastguard Worker ("f,model-format", modelFormatDescription, cxxopts::value<std::string>(modelFormat))
91*89c4ff92SAndroid Build Coastguard Worker ("m,model-path", "Path to model file.", cxxopts::value<std::string>(modelPath))
92*89c4ff92SAndroid Build Coastguard Worker
93*89c4ff92SAndroid Build Coastguard Worker ("i,input-name", "Identifier of the input tensors in the network. "
94*89c4ff92SAndroid Build Coastguard Worker "Each input must be specified separately.",
95*89c4ff92SAndroid Build Coastguard Worker cxxopts::value<std::vector<std::string>>(inputNames))
96*89c4ff92SAndroid Build Coastguard Worker ("s,input-tensor-shape",
97*89c4ff92SAndroid Build Coastguard Worker "The shape of the input tensor in the network as a flat array of integers, "
98*89c4ff92SAndroid Build Coastguard Worker "separated by comma. Each input shape must be specified separately after the input name. "
99*89c4ff92SAndroid Build Coastguard Worker "This parameter is optional, depending on the network.",
100*89c4ff92SAndroid Build Coastguard Worker cxxopts::value<std::vector<std::string>>(inputTensorShapeStrs))
101*89c4ff92SAndroid Build Coastguard Worker
102*89c4ff92SAndroid Build Coastguard Worker ("o,output-name", "Identifier of the output tensor in the network.",
103*89c4ff92SAndroid Build Coastguard Worker cxxopts::value<std::vector<std::string>>(outputNames))
104*89c4ff92SAndroid Build Coastguard Worker ("p,output-path",
105*89c4ff92SAndroid Build Coastguard Worker "Path to serialize the network to.", cxxopts::value<std::string>(outputPath));
106*89c4ff92SAndroid Build Coastguard Worker }
107*89c4ff92SAndroid Build Coastguard Worker catch (const std::exception& e)
108*89c4ff92SAndroid Build Coastguard Worker {
109*89c4ff92SAndroid Build Coastguard Worker std::cerr << e.what() << std::endl << options.help() << std::endl;
110*89c4ff92SAndroid Build Coastguard Worker return EXIT_FAILURE;
111*89c4ff92SAndroid Build Coastguard Worker }
112*89c4ff92SAndroid Build Coastguard Worker try
113*89c4ff92SAndroid Build Coastguard Worker {
114*89c4ff92SAndroid Build Coastguard Worker cxxopts::ParseResult result = options.parse(argc, argv);
115*89c4ff92SAndroid Build Coastguard Worker if (result.count("help"))
116*89c4ff92SAndroid Build Coastguard Worker {
117*89c4ff92SAndroid Build Coastguard Worker std::cerr << options.help() << std::endl;
118*89c4ff92SAndroid Build Coastguard Worker return EXIT_SUCCESS;
119*89c4ff92SAndroid Build Coastguard Worker }
120*89c4ff92SAndroid Build Coastguard Worker // Check for mandatory single options.
121*89c4ff92SAndroid Build Coastguard Worker std::string mandatorySingleParameters[] = { "model-format", "model-path", "output-name", "output-path" };
122*89c4ff92SAndroid Build Coastguard Worker bool somethingsMissing = false;
123*89c4ff92SAndroid Build Coastguard Worker for (auto param : mandatorySingleParameters)
124*89c4ff92SAndroid Build Coastguard Worker {
125*89c4ff92SAndroid Build Coastguard Worker if (result.count(param) != 1)
126*89c4ff92SAndroid Build Coastguard Worker {
127*89c4ff92SAndroid Build Coastguard Worker std::cerr << "Parameter \'--" << param << "\' is required but missing." << std::endl;
128*89c4ff92SAndroid Build Coastguard Worker somethingsMissing = true;
129*89c4ff92SAndroid Build Coastguard Worker }
130*89c4ff92SAndroid Build Coastguard Worker }
131*89c4ff92SAndroid Build Coastguard Worker // Check at least one "input-name" option.
132*89c4ff92SAndroid Build Coastguard Worker if (result.count("input-name") == 0)
133*89c4ff92SAndroid Build Coastguard Worker {
134*89c4ff92SAndroid Build Coastguard Worker std::cerr << "Parameter \'--" << "input-name" << "\' must be specified at least once." << std::endl;
135*89c4ff92SAndroid Build Coastguard Worker somethingsMissing = true;
136*89c4ff92SAndroid Build Coastguard Worker }
137*89c4ff92SAndroid Build Coastguard Worker // If input-tensor-shape is specified then there must be a 1:1 match with input-name.
138*89c4ff92SAndroid Build Coastguard Worker if (result.count("input-tensor-shape") > 0)
139*89c4ff92SAndroid Build Coastguard Worker {
140*89c4ff92SAndroid Build Coastguard Worker if (result.count("input-tensor-shape") != result.count("input-name"))
141*89c4ff92SAndroid Build Coastguard Worker {
142*89c4ff92SAndroid Build Coastguard Worker std::cerr << "When specifying \'input-tensor-shape\' a matching number of \'input-name\' parameters "
143*89c4ff92SAndroid Build Coastguard Worker "must be specified." << std::endl;
144*89c4ff92SAndroid Build Coastguard Worker somethingsMissing = true;
145*89c4ff92SAndroid Build Coastguard Worker }
146*89c4ff92SAndroid Build Coastguard Worker }
147*89c4ff92SAndroid Build Coastguard Worker
148*89c4ff92SAndroid Build Coastguard Worker if (somethingsMissing)
149*89c4ff92SAndroid Build Coastguard Worker {
150*89c4ff92SAndroid Build Coastguard Worker std::cerr << options.help() << std::endl;
151*89c4ff92SAndroid Build Coastguard Worker return EXIT_FAILURE;
152*89c4ff92SAndroid Build Coastguard Worker }
153*89c4ff92SAndroid Build Coastguard Worker }
154*89c4ff92SAndroid Build Coastguard Worker catch (const cxxopts::OptionException& e)
155*89c4ff92SAndroid Build Coastguard Worker {
156*89c4ff92SAndroid Build Coastguard Worker std::cerr << e.what() << std::endl << std::endl;
157*89c4ff92SAndroid Build Coastguard Worker return EXIT_FAILURE;
158*89c4ff92SAndroid Build Coastguard Worker }
159*89c4ff92SAndroid Build Coastguard Worker
160*89c4ff92SAndroid Build Coastguard Worker if (modelFormat.find("bin") != std::string::npos)
161*89c4ff92SAndroid Build Coastguard Worker {
162*89c4ff92SAndroid Build Coastguard Worker isModelBinary = true;
163*89c4ff92SAndroid Build Coastguard Worker }
164*89c4ff92SAndroid Build Coastguard Worker else if (modelFormat.find("text") != std::string::npos)
165*89c4ff92SAndroid Build Coastguard Worker {
166*89c4ff92SAndroid Build Coastguard Worker isModelBinary = false;
167*89c4ff92SAndroid Build Coastguard Worker }
168*89c4ff92SAndroid Build Coastguard Worker else
169*89c4ff92SAndroid Build Coastguard Worker {
170*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(fatal) << "Unknown model format: '" << modelFormat << "'. Please include 'binary' or 'text'";
171*89c4ff92SAndroid Build Coastguard Worker return EXIT_FAILURE;
172*89c4ff92SAndroid Build Coastguard Worker }
173*89c4ff92SAndroid Build Coastguard Worker
174*89c4ff92SAndroid Build Coastguard Worker return EXIT_SUCCESS;
175*89c4ff92SAndroid Build Coastguard Worker }
176*89c4ff92SAndroid Build Coastguard Worker
/// Empty tag type carrying a parser interface T; ArmnnConverter uses it to pick a CreateNetwork overload.
template <typename T>
struct ParserType
{
    using parserType = T;
};
182*89c4ff92SAndroid Build Coastguard Worker
183*89c4ff92SAndroid Build Coastguard Worker class ArmnnConverter
184*89c4ff92SAndroid Build Coastguard Worker {
185*89c4ff92SAndroid Build Coastguard Worker public:
ArmnnConverter(const std::string & modelPath,const std::vector<std::string> & inputNames,const std::vector<armnn::TensorShape> & inputShapes,const std::vector<std::string> & outputNames,const std::string & outputPath,bool isModelBinary)186*89c4ff92SAndroid Build Coastguard Worker ArmnnConverter(const std::string& modelPath,
187*89c4ff92SAndroid Build Coastguard Worker const std::vector<std::string>& inputNames,
188*89c4ff92SAndroid Build Coastguard Worker const std::vector<armnn::TensorShape>& inputShapes,
189*89c4ff92SAndroid Build Coastguard Worker const std::vector<std::string>& outputNames,
190*89c4ff92SAndroid Build Coastguard Worker const std::string& outputPath,
191*89c4ff92SAndroid Build Coastguard Worker bool isModelBinary)
192*89c4ff92SAndroid Build Coastguard Worker : m_NetworkPtr(armnn::INetworkPtr(nullptr, [](armnn::INetwork *){})),
193*89c4ff92SAndroid Build Coastguard Worker m_ModelPath(modelPath),
194*89c4ff92SAndroid Build Coastguard Worker m_InputNames(inputNames),
195*89c4ff92SAndroid Build Coastguard Worker m_InputShapes(inputShapes),
196*89c4ff92SAndroid Build Coastguard Worker m_OutputNames(outputNames),
197*89c4ff92SAndroid Build Coastguard Worker m_OutputPath(outputPath),
198*89c4ff92SAndroid Build Coastguard Worker m_IsModelBinary(isModelBinary) {}
199*89c4ff92SAndroid Build Coastguard Worker
Serialize()200*89c4ff92SAndroid Build Coastguard Worker bool Serialize()
201*89c4ff92SAndroid Build Coastguard Worker {
202*89c4ff92SAndroid Build Coastguard Worker if (m_NetworkPtr.get() == nullptr)
203*89c4ff92SAndroid Build Coastguard Worker {
204*89c4ff92SAndroid Build Coastguard Worker return false;
205*89c4ff92SAndroid Build Coastguard Worker }
206*89c4ff92SAndroid Build Coastguard Worker
207*89c4ff92SAndroid Build Coastguard Worker auto serializer(armnnSerializer::ISerializer::Create());
208*89c4ff92SAndroid Build Coastguard Worker
209*89c4ff92SAndroid Build Coastguard Worker serializer->Serialize(*m_NetworkPtr);
210*89c4ff92SAndroid Build Coastguard Worker
211*89c4ff92SAndroid Build Coastguard Worker std::ofstream file(m_OutputPath, std::ios::out | std::ios::binary);
212*89c4ff92SAndroid Build Coastguard Worker
213*89c4ff92SAndroid Build Coastguard Worker bool retVal = serializer->SaveSerializedToStream(file);
214*89c4ff92SAndroid Build Coastguard Worker
215*89c4ff92SAndroid Build Coastguard Worker return retVal;
216*89c4ff92SAndroid Build Coastguard Worker }
217*89c4ff92SAndroid Build Coastguard Worker
218*89c4ff92SAndroid Build Coastguard Worker template <typename IParser>
CreateNetwork()219*89c4ff92SAndroid Build Coastguard Worker bool CreateNetwork ()
220*89c4ff92SAndroid Build Coastguard Worker {
221*89c4ff92SAndroid Build Coastguard Worker return CreateNetwork (ParserType<IParser>());
222*89c4ff92SAndroid Build Coastguard Worker }
223*89c4ff92SAndroid Build Coastguard Worker
224*89c4ff92SAndroid Build Coastguard Worker private:
225*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr m_NetworkPtr;
226*89c4ff92SAndroid Build Coastguard Worker std::string m_ModelPath;
227*89c4ff92SAndroid Build Coastguard Worker std::vector<std::string> m_InputNames;
228*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::TensorShape> m_InputShapes;
229*89c4ff92SAndroid Build Coastguard Worker std::vector<std::string> m_OutputNames;
230*89c4ff92SAndroid Build Coastguard Worker std::string m_OutputPath;
231*89c4ff92SAndroid Build Coastguard Worker bool m_IsModelBinary;
232*89c4ff92SAndroid Build Coastguard Worker
233*89c4ff92SAndroid Build Coastguard Worker template <typename IParser>
CreateNetwork(ParserType<IParser>)234*89c4ff92SAndroid Build Coastguard Worker bool CreateNetwork (ParserType<IParser>)
235*89c4ff92SAndroid Build Coastguard Worker {
236*89c4ff92SAndroid Build Coastguard Worker // Create a network from a file on disk
237*89c4ff92SAndroid Build Coastguard Worker auto parser(IParser::Create());
238*89c4ff92SAndroid Build Coastguard Worker
239*89c4ff92SAndroid Build Coastguard Worker std::map<std::string, armnn::TensorShape> inputShapes;
240*89c4ff92SAndroid Build Coastguard Worker if (!m_InputShapes.empty())
241*89c4ff92SAndroid Build Coastguard Worker {
242*89c4ff92SAndroid Build Coastguard Worker const size_t numInputShapes = m_InputShapes.size();
243*89c4ff92SAndroid Build Coastguard Worker const size_t numInputBindings = m_InputNames.size();
244*89c4ff92SAndroid Build Coastguard Worker if (numInputShapes < numInputBindings)
245*89c4ff92SAndroid Build Coastguard Worker {
246*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception(fmt::format(
247*89c4ff92SAndroid Build Coastguard Worker "Not every input has its tensor shape specified: expected={0}, got={1}",
248*89c4ff92SAndroid Build Coastguard Worker numInputBindings, numInputShapes));
249*89c4ff92SAndroid Build Coastguard Worker }
250*89c4ff92SAndroid Build Coastguard Worker
251*89c4ff92SAndroid Build Coastguard Worker for (size_t i = 0; i < numInputShapes; i++)
252*89c4ff92SAndroid Build Coastguard Worker {
253*89c4ff92SAndroid Build Coastguard Worker inputShapes[m_InputNames[i]] = m_InputShapes[i];
254*89c4ff92SAndroid Build Coastguard Worker }
255*89c4ff92SAndroid Build Coastguard Worker }
256*89c4ff92SAndroid Build Coastguard Worker
257*89c4ff92SAndroid Build Coastguard Worker {
258*89c4ff92SAndroid Build Coastguard Worker ARMNN_SCOPED_HEAP_PROFILING("Parsing");
259*89c4ff92SAndroid Build Coastguard Worker m_NetworkPtr = (m_IsModelBinary ?
260*89c4ff92SAndroid Build Coastguard Worker parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str(), inputShapes, m_OutputNames) :
261*89c4ff92SAndroid Build Coastguard Worker parser->CreateNetworkFromTextFile(m_ModelPath.c_str(), inputShapes, m_OutputNames));
262*89c4ff92SAndroid Build Coastguard Worker }
263*89c4ff92SAndroid Build Coastguard Worker
264*89c4ff92SAndroid Build Coastguard Worker return m_NetworkPtr.get() != nullptr;
265*89c4ff92SAndroid Build Coastguard Worker }
266*89c4ff92SAndroid Build Coastguard Worker
267*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_TF_LITE_PARSER)
CreateNetwork(ParserType<armnnTfLiteParser::ITfLiteParser>)268*89c4ff92SAndroid Build Coastguard Worker bool CreateNetwork (ParserType<armnnTfLiteParser::ITfLiteParser>)
269*89c4ff92SAndroid Build Coastguard Worker {
270*89c4ff92SAndroid Build Coastguard Worker // Create a network from a file on disk
271*89c4ff92SAndroid Build Coastguard Worker auto parser(armnnTfLiteParser::ITfLiteParser::Create());
272*89c4ff92SAndroid Build Coastguard Worker
273*89c4ff92SAndroid Build Coastguard Worker if (!m_InputShapes.empty())
274*89c4ff92SAndroid Build Coastguard Worker {
275*89c4ff92SAndroid Build Coastguard Worker const size_t numInputShapes = m_InputShapes.size();
276*89c4ff92SAndroid Build Coastguard Worker const size_t numInputBindings = m_InputNames.size();
277*89c4ff92SAndroid Build Coastguard Worker if (numInputShapes < numInputBindings)
278*89c4ff92SAndroid Build Coastguard Worker {
279*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception(fmt::format(
280*89c4ff92SAndroid Build Coastguard Worker "Not every input has its tensor shape specified: expected={0}, got={1}",
281*89c4ff92SAndroid Build Coastguard Worker numInputBindings, numInputShapes));
282*89c4ff92SAndroid Build Coastguard Worker }
283*89c4ff92SAndroid Build Coastguard Worker }
284*89c4ff92SAndroid Build Coastguard Worker
285*89c4ff92SAndroid Build Coastguard Worker {
286*89c4ff92SAndroid Build Coastguard Worker ARMNN_SCOPED_HEAP_PROFILING("Parsing");
287*89c4ff92SAndroid Build Coastguard Worker m_NetworkPtr = parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str());
288*89c4ff92SAndroid Build Coastguard Worker }
289*89c4ff92SAndroid Build Coastguard Worker
290*89c4ff92SAndroid Build Coastguard Worker return m_NetworkPtr.get() != nullptr;
291*89c4ff92SAndroid Build Coastguard Worker }
292*89c4ff92SAndroid Build Coastguard Worker #endif
293*89c4ff92SAndroid Build Coastguard Worker
294*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_ONNX_PARSER)
CreateNetwork(ParserType<armnnOnnxParser::IOnnxParser>)295*89c4ff92SAndroid Build Coastguard Worker bool CreateNetwork (ParserType<armnnOnnxParser::IOnnxParser>)
296*89c4ff92SAndroid Build Coastguard Worker {
297*89c4ff92SAndroid Build Coastguard Worker // Create a network from a file on disk
298*89c4ff92SAndroid Build Coastguard Worker auto parser(armnnOnnxParser::IOnnxParser::Create());
299*89c4ff92SAndroid Build Coastguard Worker
300*89c4ff92SAndroid Build Coastguard Worker if (!m_InputShapes.empty())
301*89c4ff92SAndroid Build Coastguard Worker {
302*89c4ff92SAndroid Build Coastguard Worker const size_t numInputShapes = m_InputShapes.size();
303*89c4ff92SAndroid Build Coastguard Worker const size_t numInputBindings = m_InputNames.size();
304*89c4ff92SAndroid Build Coastguard Worker if (numInputShapes < numInputBindings)
305*89c4ff92SAndroid Build Coastguard Worker {
306*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception(fmt::format(
307*89c4ff92SAndroid Build Coastguard Worker "Not every input has its tensor shape specified: expected={0}, got={1}",
308*89c4ff92SAndroid Build Coastguard Worker numInputBindings, numInputShapes));
309*89c4ff92SAndroid Build Coastguard Worker }
310*89c4ff92SAndroid Build Coastguard Worker }
311*89c4ff92SAndroid Build Coastguard Worker
312*89c4ff92SAndroid Build Coastguard Worker {
313*89c4ff92SAndroid Build Coastguard Worker ARMNN_SCOPED_HEAP_PROFILING("Parsing");
314*89c4ff92SAndroid Build Coastguard Worker m_NetworkPtr = (m_IsModelBinary ?
315*89c4ff92SAndroid Build Coastguard Worker parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str()) :
316*89c4ff92SAndroid Build Coastguard Worker parser->CreateNetworkFromTextFile(m_ModelPath.c_str()));
317*89c4ff92SAndroid Build Coastguard Worker }
318*89c4ff92SAndroid Build Coastguard Worker
319*89c4ff92SAndroid Build Coastguard Worker return m_NetworkPtr.get() != nullptr;
320*89c4ff92SAndroid Build Coastguard Worker }
321*89c4ff92SAndroid Build Coastguard Worker #endif
322*89c4ff92SAndroid Build Coastguard Worker
323*89c4ff92SAndroid Build Coastguard Worker };
324*89c4ff92SAndroid Build Coastguard Worker
325*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace
326*89c4ff92SAndroid Build Coastguard Worker
main(int argc,char * argv[])327*89c4ff92SAndroid Build Coastguard Worker int main(int argc, char* argv[])
328*89c4ff92SAndroid Build Coastguard Worker {
329*89c4ff92SAndroid Build Coastguard Worker
330*89c4ff92SAndroid Build Coastguard Worker #if (!defined(ARMNN_ONNX_PARSER) \
331*89c4ff92SAndroid Build Coastguard Worker && !defined(ARMNN_TF_PARSER) \
332*89c4ff92SAndroid Build Coastguard Worker && !defined(ARMNN_TF_LITE_PARSER))
333*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(fatal) << "Not built with any of the supported parsers Onnx, Tensorflow, or TfLite.";
334*89c4ff92SAndroid Build Coastguard Worker return EXIT_FAILURE;
335*89c4ff92SAndroid Build Coastguard Worker #endif
336*89c4ff92SAndroid Build Coastguard Worker
337*89c4ff92SAndroid Build Coastguard Worker #if !defined(ARMNN_SERIALIZER)
338*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(fatal) << "Not built with Serializer support.";
339*89c4ff92SAndroid Build Coastguard Worker return EXIT_FAILURE;
340*89c4ff92SAndroid Build Coastguard Worker #endif
341*89c4ff92SAndroid Build Coastguard Worker
342*89c4ff92SAndroid Build Coastguard Worker #ifdef NDEBUG
343*89c4ff92SAndroid Build Coastguard Worker armnn::LogSeverity level = armnn::LogSeverity::Info;
344*89c4ff92SAndroid Build Coastguard Worker #else
345*89c4ff92SAndroid Build Coastguard Worker armnn::LogSeverity level = armnn::LogSeverity::Debug;
346*89c4ff92SAndroid Build Coastguard Worker #endif
347*89c4ff92SAndroid Build Coastguard Worker
348*89c4ff92SAndroid Build Coastguard Worker armnn::ConfigureLogging(true, true, level);
349*89c4ff92SAndroid Build Coastguard Worker
350*89c4ff92SAndroid Build Coastguard Worker std::string modelFormat;
351*89c4ff92SAndroid Build Coastguard Worker std::string modelPath;
352*89c4ff92SAndroid Build Coastguard Worker
353*89c4ff92SAndroid Build Coastguard Worker std::vector<std::string> inputNames;
354*89c4ff92SAndroid Build Coastguard Worker std::vector<std::string> inputTensorShapeStrs;
355*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::TensorShape> inputTensorShapes;
356*89c4ff92SAndroid Build Coastguard Worker
357*89c4ff92SAndroid Build Coastguard Worker std::vector<std::string> outputNames;
358*89c4ff92SAndroid Build Coastguard Worker std::string outputPath;
359*89c4ff92SAndroid Build Coastguard Worker
360*89c4ff92SAndroid Build Coastguard Worker bool isModelBinary = true;
361*89c4ff92SAndroid Build Coastguard Worker
362*89c4ff92SAndroid Build Coastguard Worker if (ParseCommandLineArgs(
363*89c4ff92SAndroid Build Coastguard Worker argc, argv, modelFormat, modelPath, inputNames, inputTensorShapeStrs, outputNames, outputPath, isModelBinary)
364*89c4ff92SAndroid Build Coastguard Worker != EXIT_SUCCESS)
365*89c4ff92SAndroid Build Coastguard Worker {
366*89c4ff92SAndroid Build Coastguard Worker return EXIT_FAILURE;
367*89c4ff92SAndroid Build Coastguard Worker }
368*89c4ff92SAndroid Build Coastguard Worker
369*89c4ff92SAndroid Build Coastguard Worker for (const std::string& shapeStr : inputTensorShapeStrs)
370*89c4ff92SAndroid Build Coastguard Worker {
371*89c4ff92SAndroid Build Coastguard Worker if (!shapeStr.empty())
372*89c4ff92SAndroid Build Coastguard Worker {
373*89c4ff92SAndroid Build Coastguard Worker std::stringstream ss(shapeStr);
374*89c4ff92SAndroid Build Coastguard Worker
375*89c4ff92SAndroid Build Coastguard Worker try
376*89c4ff92SAndroid Build Coastguard Worker {
377*89c4ff92SAndroid Build Coastguard Worker armnn::TensorShape shape = ParseTensorShape(ss);
378*89c4ff92SAndroid Build Coastguard Worker inputTensorShapes.push_back(shape);
379*89c4ff92SAndroid Build Coastguard Worker }
380*89c4ff92SAndroid Build Coastguard Worker catch (const armnn::InvalidArgumentException& e)
381*89c4ff92SAndroid Build Coastguard Worker {
382*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(fatal) << "Cannot create tensor shape: " << e.what();
383*89c4ff92SAndroid Build Coastguard Worker return EXIT_FAILURE;
384*89c4ff92SAndroid Build Coastguard Worker }
385*89c4ff92SAndroid Build Coastguard Worker }
386*89c4ff92SAndroid Build Coastguard Worker }
387*89c4ff92SAndroid Build Coastguard Worker
388*89c4ff92SAndroid Build Coastguard Worker ArmnnConverter converter(modelPath, inputNames, inputTensorShapes, outputNames, outputPath, isModelBinary);
389*89c4ff92SAndroid Build Coastguard Worker
390*89c4ff92SAndroid Build Coastguard Worker try
391*89c4ff92SAndroid Build Coastguard Worker {
392*89c4ff92SAndroid Build Coastguard Worker if (modelFormat.find("onnx") != std::string::npos)
393*89c4ff92SAndroid Build Coastguard Worker {
394*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_ONNX_PARSER)
395*89c4ff92SAndroid Build Coastguard Worker if (!converter.CreateNetwork<armnnOnnxParser::IOnnxParser>())
396*89c4ff92SAndroid Build Coastguard Worker {
397*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(fatal) << "Failed to load model from file";
398*89c4ff92SAndroid Build Coastguard Worker return EXIT_FAILURE;
399*89c4ff92SAndroid Build Coastguard Worker }
400*89c4ff92SAndroid Build Coastguard Worker #else
401*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(fatal) << "Not built with Onnx parser support.";
402*89c4ff92SAndroid Build Coastguard Worker return EXIT_FAILURE;
403*89c4ff92SAndroid Build Coastguard Worker #endif
404*89c4ff92SAndroid Build Coastguard Worker }
405*89c4ff92SAndroid Build Coastguard Worker else if (modelFormat.find("tflite") != std::string::npos)
406*89c4ff92SAndroid Build Coastguard Worker {
407*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_TF_LITE_PARSER)
408*89c4ff92SAndroid Build Coastguard Worker if (!isModelBinary)
409*89c4ff92SAndroid Build Coastguard Worker {
410*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(fatal) << "Unknown model format: '" << modelFormat << "'. Only 'binary' format supported \
411*89c4ff92SAndroid Build Coastguard Worker for tflite files";
412*89c4ff92SAndroid Build Coastguard Worker return EXIT_FAILURE;
413*89c4ff92SAndroid Build Coastguard Worker }
414*89c4ff92SAndroid Build Coastguard Worker
415*89c4ff92SAndroid Build Coastguard Worker if (!converter.CreateNetwork<armnnTfLiteParser::ITfLiteParser>())
416*89c4ff92SAndroid Build Coastguard Worker {
417*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(fatal) << "Failed to load model from file";
418*89c4ff92SAndroid Build Coastguard Worker return EXIT_FAILURE;
419*89c4ff92SAndroid Build Coastguard Worker }
420*89c4ff92SAndroid Build Coastguard Worker #else
421*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(fatal) << "Not built with TfLite parser support.";
422*89c4ff92SAndroid Build Coastguard Worker return EXIT_FAILURE;
423*89c4ff92SAndroid Build Coastguard Worker #endif
424*89c4ff92SAndroid Build Coastguard Worker }
425*89c4ff92SAndroid Build Coastguard Worker else
426*89c4ff92SAndroid Build Coastguard Worker {
427*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(fatal) << "Unknown model format: '" << modelFormat << "'";
428*89c4ff92SAndroid Build Coastguard Worker return EXIT_FAILURE;
429*89c4ff92SAndroid Build Coastguard Worker }
430*89c4ff92SAndroid Build Coastguard Worker }
431*89c4ff92SAndroid Build Coastguard Worker catch(armnn::Exception& e)
432*89c4ff92SAndroid Build Coastguard Worker {
433*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(fatal) << "Failed to load model from file: " << e.what();
434*89c4ff92SAndroid Build Coastguard Worker return EXIT_FAILURE;
435*89c4ff92SAndroid Build Coastguard Worker }
436*89c4ff92SAndroid Build Coastguard Worker
437*89c4ff92SAndroid Build Coastguard Worker if (!converter.Serialize())
438*89c4ff92SAndroid Build Coastguard Worker {
439*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(fatal) << "Failed to serialize model";
440*89c4ff92SAndroid Build Coastguard Worker return EXIT_FAILURE;
441*89c4ff92SAndroid Build Coastguard Worker }
442*89c4ff92SAndroid Build Coastguard Worker
443*89c4ff92SAndroid Build Coastguard Worker return EXIT_SUCCESS;
444*89c4ff92SAndroid Build Coastguard Worker }
445