xref: /aosp_15_r20/external/armnn/src/armnnConverter/ArmnnConverter.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Logging.hpp>
6*89c4ff92SAndroid Build Coastguard Worker 
7*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_ONNX_PARSER)
8*89c4ff92SAndroid Build Coastguard Worker #include <armnnOnnxParser/IOnnxParser.hpp>
9*89c4ff92SAndroid Build Coastguard Worker #endif
10*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_SERIALIZER)
11*89c4ff92SAndroid Build Coastguard Worker #include <armnnSerializer/ISerializer.hpp>
12*89c4ff92SAndroid Build Coastguard Worker #endif
13*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_TF_LITE_PARSER)
14*89c4ff92SAndroid Build Coastguard Worker #include <armnnTfLiteParser/ITfLiteParser.hpp>
15*89c4ff92SAndroid Build Coastguard Worker #endif
16*89c4ff92SAndroid Build Coastguard Worker 
17*89c4ff92SAndroid Build Coastguard Worker #include <HeapProfiling.hpp>
18*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/NumericCast.hpp>
19*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/StringUtils.hpp>
20*89c4ff92SAndroid Build Coastguard Worker 
/*
 * Historically we use the ',' character to separate dimensions in a tensor shape. However, cxxopts will read this
 * as an array of values, which is fine until we have multiple tensors specified. This lumps the values of all shapes
 * together in a single array and we cannot break it up again. We'll change the vector delimiter to a '.'. We do this
 * as close as possible to the usage of cxxopts to avoid polluting other possible uses.
 */
27*89c4ff92SAndroid Build Coastguard Worker #define CXXOPTS_VECTOR_DELIMITER '.'
28*89c4ff92SAndroid Build Coastguard Worker #include <cxxopts/cxxopts.hpp>
29*89c4ff92SAndroid Build Coastguard Worker 
30*89c4ff92SAndroid Build Coastguard Worker #include <fmt/format.h>
31*89c4ff92SAndroid Build Coastguard Worker 
32*89c4ff92SAndroid Build Coastguard Worker #include <cstdlib>
33*89c4ff92SAndroid Build Coastguard Worker #include <fstream>
34*89c4ff92SAndroid Build Coastguard Worker #include <iostream>
35*89c4ff92SAndroid Build Coastguard Worker 
36*89c4ff92SAndroid Build Coastguard Worker namespace
37*89c4ff92SAndroid Build Coastguard Worker {
38*89c4ff92SAndroid Build Coastguard Worker 
ParseTensorShape(std::istream & stream)39*89c4ff92SAndroid Build Coastguard Worker armnn::TensorShape ParseTensorShape(std::istream& stream)
40*89c4ff92SAndroid Build Coastguard Worker {
41*89c4ff92SAndroid Build Coastguard Worker     std::vector<unsigned int> result;
42*89c4ff92SAndroid Build Coastguard Worker     std::string line;
43*89c4ff92SAndroid Build Coastguard Worker 
44*89c4ff92SAndroid Build Coastguard Worker     while (std::getline(stream, line))
45*89c4ff92SAndroid Build Coastguard Worker     {
46*89c4ff92SAndroid Build Coastguard Worker         std::vector<std::string> tokens = armnn::stringUtils::StringTokenizer(line, ",");
47*89c4ff92SAndroid Build Coastguard Worker         for (const std::string& token : tokens)
48*89c4ff92SAndroid Build Coastguard Worker         {
49*89c4ff92SAndroid Build Coastguard Worker             if (!token.empty())
50*89c4ff92SAndroid Build Coastguard Worker             {
51*89c4ff92SAndroid Build Coastguard Worker                 try
52*89c4ff92SAndroid Build Coastguard Worker                 {
53*89c4ff92SAndroid Build Coastguard Worker                     result.push_back(armnn::numeric_cast<unsigned int>(std::stoi((token))));
54*89c4ff92SAndroid Build Coastguard Worker                 }
55*89c4ff92SAndroid Build Coastguard Worker                 catch (const std::exception&)
56*89c4ff92SAndroid Build Coastguard Worker                 {
57*89c4ff92SAndroid Build Coastguard Worker                     ARMNN_LOG(error) << "'" << token << "' is not a valid number. It has been ignored.";
58*89c4ff92SAndroid Build Coastguard Worker                 }
59*89c4ff92SAndroid Build Coastguard Worker             }
60*89c4ff92SAndroid Build Coastguard Worker         }
61*89c4ff92SAndroid Build Coastguard Worker     }
62*89c4ff92SAndroid Build Coastguard Worker 
63*89c4ff92SAndroid Build Coastguard Worker     return armnn::TensorShape(armnn::numeric_cast<unsigned int>(result.size()), result.data());
64*89c4ff92SAndroid Build Coastguard Worker }
65*89c4ff92SAndroid Build Coastguard Worker 
ParseCommandLineArgs(int argc,char * argv[],std::string & modelFormat,std::string & modelPath,std::vector<std::string> & inputNames,std::vector<std::string> & inputTensorShapeStrs,std::vector<std::string> & outputNames,std::string & outputPath,bool & isModelBinary)66*89c4ff92SAndroid Build Coastguard Worker int ParseCommandLineArgs(int argc, char* argv[],
67*89c4ff92SAndroid Build Coastguard Worker                          std::string& modelFormat,
68*89c4ff92SAndroid Build Coastguard Worker                          std::string& modelPath,
69*89c4ff92SAndroid Build Coastguard Worker                          std::vector<std::string>& inputNames,
70*89c4ff92SAndroid Build Coastguard Worker                          std::vector<std::string>& inputTensorShapeStrs,
71*89c4ff92SAndroid Build Coastguard Worker                          std::vector<std::string>& outputNames,
72*89c4ff92SAndroid Build Coastguard Worker                          std::string& outputPath, bool& isModelBinary)
73*89c4ff92SAndroid Build Coastguard Worker {
74*89c4ff92SAndroid Build Coastguard Worker     cxxopts::Options options("ArmNNConverter", "Convert a neural network model from provided file to ArmNN format.");
75*89c4ff92SAndroid Build Coastguard Worker     try
76*89c4ff92SAndroid Build Coastguard Worker     {
77*89c4ff92SAndroid Build Coastguard Worker         std::string modelFormatDescription("Format of the model file");
78*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_ONNX_PARSER)
79*89c4ff92SAndroid Build Coastguard Worker         modelFormatDescription += ", onnx-binary, onnx-text";
80*89c4ff92SAndroid Build Coastguard Worker #endif
81*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_TF_PARSER)
82*89c4ff92SAndroid Build Coastguard Worker         modelFormatDescription += ", tensorflow-binary, tensorflow-text";
83*89c4ff92SAndroid Build Coastguard Worker #endif
84*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_TF_LITE_PARSER)
85*89c4ff92SAndroid Build Coastguard Worker         modelFormatDescription += ", tflite-binary";
86*89c4ff92SAndroid Build Coastguard Worker #endif
87*89c4ff92SAndroid Build Coastguard Worker         modelFormatDescription += ".";
88*89c4ff92SAndroid Build Coastguard Worker         options.add_options()
89*89c4ff92SAndroid Build Coastguard Worker             ("help", "Display usage information")
90*89c4ff92SAndroid Build Coastguard Worker             ("f,model-format", modelFormatDescription, cxxopts::value<std::string>(modelFormat))
91*89c4ff92SAndroid Build Coastguard Worker             ("m,model-path", "Path to model file.", cxxopts::value<std::string>(modelPath))
92*89c4ff92SAndroid Build Coastguard Worker 
93*89c4ff92SAndroid Build Coastguard Worker             ("i,input-name", "Identifier of the input tensors in the network. "
94*89c4ff92SAndroid Build Coastguard Worker                              "Each input must be specified separately.",
95*89c4ff92SAndroid Build Coastguard Worker                              cxxopts::value<std::vector<std::string>>(inputNames))
96*89c4ff92SAndroid Build Coastguard Worker             ("s,input-tensor-shape",
97*89c4ff92SAndroid Build Coastguard Worker                              "The shape of the input tensor in the network as a flat array of integers, "
98*89c4ff92SAndroid Build Coastguard Worker                              "separated by comma. Each input shape must be specified separately after the input name. "
99*89c4ff92SAndroid Build Coastguard Worker                              "This parameter is optional, depending on the network.",
100*89c4ff92SAndroid Build Coastguard Worker                              cxxopts::value<std::vector<std::string>>(inputTensorShapeStrs))
101*89c4ff92SAndroid Build Coastguard Worker 
102*89c4ff92SAndroid Build Coastguard Worker             ("o,output-name", "Identifier of the output tensor in the network.",
103*89c4ff92SAndroid Build Coastguard Worker                               cxxopts::value<std::vector<std::string>>(outputNames))
104*89c4ff92SAndroid Build Coastguard Worker             ("p,output-path",
105*89c4ff92SAndroid Build Coastguard Worker                          "Path to serialize the network to.", cxxopts::value<std::string>(outputPath));
106*89c4ff92SAndroid Build Coastguard Worker     }
107*89c4ff92SAndroid Build Coastguard Worker     catch (const std::exception& e)
108*89c4ff92SAndroid Build Coastguard Worker     {
109*89c4ff92SAndroid Build Coastguard Worker         std::cerr << e.what() << std::endl << options.help() << std::endl;
110*89c4ff92SAndroid Build Coastguard Worker         return EXIT_FAILURE;
111*89c4ff92SAndroid Build Coastguard Worker     }
112*89c4ff92SAndroid Build Coastguard Worker     try
113*89c4ff92SAndroid Build Coastguard Worker     {
114*89c4ff92SAndroid Build Coastguard Worker         cxxopts::ParseResult result = options.parse(argc, argv);
115*89c4ff92SAndroid Build Coastguard Worker         if (result.count("help"))
116*89c4ff92SAndroid Build Coastguard Worker         {
117*89c4ff92SAndroid Build Coastguard Worker             std::cerr << options.help()  << std::endl;
118*89c4ff92SAndroid Build Coastguard Worker             return EXIT_SUCCESS;
119*89c4ff92SAndroid Build Coastguard Worker         }
120*89c4ff92SAndroid Build Coastguard Worker         // Check for mandatory single options.
121*89c4ff92SAndroid Build Coastguard Worker         std::string mandatorySingleParameters[] = { "model-format", "model-path", "output-name", "output-path" };
122*89c4ff92SAndroid Build Coastguard Worker         bool somethingsMissing = false;
123*89c4ff92SAndroid Build Coastguard Worker         for (auto param : mandatorySingleParameters)
124*89c4ff92SAndroid Build Coastguard Worker         {
125*89c4ff92SAndroid Build Coastguard Worker             if (result.count(param) != 1)
126*89c4ff92SAndroid Build Coastguard Worker             {
127*89c4ff92SAndroid Build Coastguard Worker                 std::cerr << "Parameter \'--" << param << "\' is required but missing." << std::endl;
128*89c4ff92SAndroid Build Coastguard Worker                 somethingsMissing = true;
129*89c4ff92SAndroid Build Coastguard Worker             }
130*89c4ff92SAndroid Build Coastguard Worker         }
131*89c4ff92SAndroid Build Coastguard Worker         // Check at least one "input-name" option.
132*89c4ff92SAndroid Build Coastguard Worker         if (result.count("input-name") == 0)
133*89c4ff92SAndroid Build Coastguard Worker         {
134*89c4ff92SAndroid Build Coastguard Worker             std::cerr << "Parameter \'--" << "input-name" << "\' must be specified at least once." << std::endl;
135*89c4ff92SAndroid Build Coastguard Worker             somethingsMissing = true;
136*89c4ff92SAndroid Build Coastguard Worker         }
137*89c4ff92SAndroid Build Coastguard Worker         // If input-tensor-shape is specified then there must be a 1:1 match with input-name.
138*89c4ff92SAndroid Build Coastguard Worker         if (result.count("input-tensor-shape") > 0)
139*89c4ff92SAndroid Build Coastguard Worker         {
140*89c4ff92SAndroid Build Coastguard Worker             if (result.count("input-tensor-shape") != result.count("input-name"))
141*89c4ff92SAndroid Build Coastguard Worker             {
142*89c4ff92SAndroid Build Coastguard Worker                 std::cerr << "When specifying \'input-tensor-shape\' a matching number of \'input-name\' parameters "
143*89c4ff92SAndroid Build Coastguard Worker                              "must be specified." << std::endl;
144*89c4ff92SAndroid Build Coastguard Worker                 somethingsMissing = true;
145*89c4ff92SAndroid Build Coastguard Worker             }
146*89c4ff92SAndroid Build Coastguard Worker         }
147*89c4ff92SAndroid Build Coastguard Worker 
148*89c4ff92SAndroid Build Coastguard Worker         if (somethingsMissing)
149*89c4ff92SAndroid Build Coastguard Worker         {
150*89c4ff92SAndroid Build Coastguard Worker             std::cerr << options.help()  << std::endl;
151*89c4ff92SAndroid Build Coastguard Worker             return EXIT_FAILURE;
152*89c4ff92SAndroid Build Coastguard Worker         }
153*89c4ff92SAndroid Build Coastguard Worker     }
154*89c4ff92SAndroid Build Coastguard Worker     catch (const cxxopts::OptionException& e)
155*89c4ff92SAndroid Build Coastguard Worker     {
156*89c4ff92SAndroid Build Coastguard Worker         std::cerr << e.what() << std::endl << std::endl;
157*89c4ff92SAndroid Build Coastguard Worker         return EXIT_FAILURE;
158*89c4ff92SAndroid Build Coastguard Worker     }
159*89c4ff92SAndroid Build Coastguard Worker 
160*89c4ff92SAndroid Build Coastguard Worker     if (modelFormat.find("bin") != std::string::npos)
161*89c4ff92SAndroid Build Coastguard Worker     {
162*89c4ff92SAndroid Build Coastguard Worker         isModelBinary = true;
163*89c4ff92SAndroid Build Coastguard Worker     }
164*89c4ff92SAndroid Build Coastguard Worker     else if (modelFormat.find("text") != std::string::npos)
165*89c4ff92SAndroid Build Coastguard Worker     {
166*89c4ff92SAndroid Build Coastguard Worker         isModelBinary = false;
167*89c4ff92SAndroid Build Coastguard Worker     }
168*89c4ff92SAndroid Build Coastguard Worker     else
169*89c4ff92SAndroid Build Coastguard Worker     {
170*89c4ff92SAndroid Build Coastguard Worker         ARMNN_LOG(fatal) << "Unknown model format: '" << modelFormat << "'. Please include 'binary' or 'text'";
171*89c4ff92SAndroid Build Coastguard Worker         return EXIT_FAILURE;
172*89c4ff92SAndroid Build Coastguard Worker     }
173*89c4ff92SAndroid Build Coastguard Worker 
174*89c4ff92SAndroid Build Coastguard Worker     return EXIT_SUCCESS;
175*89c4ff92SAndroid Build Coastguard Worker }
176*89c4ff92SAndroid Build Coastguard Worker 
// Empty tag type carrying a parser interface as a compile-time parameter. ArmnnConverter passes an instance of it
// to pick the matching CreateNetwork overload.
template<typename T>
struct ParserType
{
    using parserType = T;
};
182*89c4ff92SAndroid Build Coastguard Worker 
183*89c4ff92SAndroid Build Coastguard Worker class ArmnnConverter
184*89c4ff92SAndroid Build Coastguard Worker {
185*89c4ff92SAndroid Build Coastguard Worker public:
ArmnnConverter(const std::string & modelPath,const std::vector<std::string> & inputNames,const std::vector<armnn::TensorShape> & inputShapes,const std::vector<std::string> & outputNames,const std::string & outputPath,bool isModelBinary)186*89c4ff92SAndroid Build Coastguard Worker     ArmnnConverter(const std::string& modelPath,
187*89c4ff92SAndroid Build Coastguard Worker                    const std::vector<std::string>& inputNames,
188*89c4ff92SAndroid Build Coastguard Worker                    const std::vector<armnn::TensorShape>& inputShapes,
189*89c4ff92SAndroid Build Coastguard Worker                    const std::vector<std::string>& outputNames,
190*89c4ff92SAndroid Build Coastguard Worker                    const std::string& outputPath,
191*89c4ff92SAndroid Build Coastguard Worker                    bool isModelBinary)
192*89c4ff92SAndroid Build Coastguard Worker     : m_NetworkPtr(armnn::INetworkPtr(nullptr, [](armnn::INetwork *){})),
193*89c4ff92SAndroid Build Coastguard Worker     m_ModelPath(modelPath),
194*89c4ff92SAndroid Build Coastguard Worker     m_InputNames(inputNames),
195*89c4ff92SAndroid Build Coastguard Worker     m_InputShapes(inputShapes),
196*89c4ff92SAndroid Build Coastguard Worker     m_OutputNames(outputNames),
197*89c4ff92SAndroid Build Coastguard Worker     m_OutputPath(outputPath),
198*89c4ff92SAndroid Build Coastguard Worker     m_IsModelBinary(isModelBinary) {}
199*89c4ff92SAndroid Build Coastguard Worker 
Serialize()200*89c4ff92SAndroid Build Coastguard Worker     bool Serialize()
201*89c4ff92SAndroid Build Coastguard Worker     {
202*89c4ff92SAndroid Build Coastguard Worker         if (m_NetworkPtr.get() == nullptr)
203*89c4ff92SAndroid Build Coastguard Worker         {
204*89c4ff92SAndroid Build Coastguard Worker             return false;
205*89c4ff92SAndroid Build Coastguard Worker         }
206*89c4ff92SAndroid Build Coastguard Worker 
207*89c4ff92SAndroid Build Coastguard Worker         auto serializer(armnnSerializer::ISerializer::Create());
208*89c4ff92SAndroid Build Coastguard Worker 
209*89c4ff92SAndroid Build Coastguard Worker         serializer->Serialize(*m_NetworkPtr);
210*89c4ff92SAndroid Build Coastguard Worker 
211*89c4ff92SAndroid Build Coastguard Worker         std::ofstream file(m_OutputPath, std::ios::out | std::ios::binary);
212*89c4ff92SAndroid Build Coastguard Worker 
213*89c4ff92SAndroid Build Coastguard Worker         bool retVal = serializer->SaveSerializedToStream(file);
214*89c4ff92SAndroid Build Coastguard Worker 
215*89c4ff92SAndroid Build Coastguard Worker         return retVal;
216*89c4ff92SAndroid Build Coastguard Worker     }
217*89c4ff92SAndroid Build Coastguard Worker 
218*89c4ff92SAndroid Build Coastguard Worker     template <typename IParser>
CreateNetwork()219*89c4ff92SAndroid Build Coastguard Worker     bool CreateNetwork ()
220*89c4ff92SAndroid Build Coastguard Worker     {
221*89c4ff92SAndroid Build Coastguard Worker         return CreateNetwork (ParserType<IParser>());
222*89c4ff92SAndroid Build Coastguard Worker     }
223*89c4ff92SAndroid Build Coastguard Worker 
224*89c4ff92SAndroid Build Coastguard Worker private:
225*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr              m_NetworkPtr;
226*89c4ff92SAndroid Build Coastguard Worker     std::string                     m_ModelPath;
227*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::string>        m_InputNames;
228*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::TensorShape> m_InputShapes;
229*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::string>        m_OutputNames;
230*89c4ff92SAndroid Build Coastguard Worker     std::string                     m_OutputPath;
231*89c4ff92SAndroid Build Coastguard Worker     bool                            m_IsModelBinary;
232*89c4ff92SAndroid Build Coastguard Worker 
233*89c4ff92SAndroid Build Coastguard Worker     template <typename IParser>
CreateNetwork(ParserType<IParser>)234*89c4ff92SAndroid Build Coastguard Worker     bool CreateNetwork (ParserType<IParser>)
235*89c4ff92SAndroid Build Coastguard Worker     {
236*89c4ff92SAndroid Build Coastguard Worker         // Create a network from a file on disk
237*89c4ff92SAndroid Build Coastguard Worker         auto parser(IParser::Create());
238*89c4ff92SAndroid Build Coastguard Worker 
239*89c4ff92SAndroid Build Coastguard Worker         std::map<std::string, armnn::TensorShape> inputShapes;
240*89c4ff92SAndroid Build Coastguard Worker         if (!m_InputShapes.empty())
241*89c4ff92SAndroid Build Coastguard Worker         {
242*89c4ff92SAndroid Build Coastguard Worker             const size_t numInputShapes   = m_InputShapes.size();
243*89c4ff92SAndroid Build Coastguard Worker             const size_t numInputBindings = m_InputNames.size();
244*89c4ff92SAndroid Build Coastguard Worker             if (numInputShapes < numInputBindings)
245*89c4ff92SAndroid Build Coastguard Worker             {
246*89c4ff92SAndroid Build Coastguard Worker                 throw armnn::Exception(fmt::format(
247*89c4ff92SAndroid Build Coastguard Worker                    "Not every input has its tensor shape specified: expected={0}, got={1}",
248*89c4ff92SAndroid Build Coastguard Worker                    numInputBindings, numInputShapes));
249*89c4ff92SAndroid Build Coastguard Worker             }
250*89c4ff92SAndroid Build Coastguard Worker 
251*89c4ff92SAndroid Build Coastguard Worker             for (size_t i = 0; i < numInputShapes; i++)
252*89c4ff92SAndroid Build Coastguard Worker             {
253*89c4ff92SAndroid Build Coastguard Worker                 inputShapes[m_InputNames[i]] = m_InputShapes[i];
254*89c4ff92SAndroid Build Coastguard Worker             }
255*89c4ff92SAndroid Build Coastguard Worker         }
256*89c4ff92SAndroid Build Coastguard Worker 
257*89c4ff92SAndroid Build Coastguard Worker         {
258*89c4ff92SAndroid Build Coastguard Worker             ARMNN_SCOPED_HEAP_PROFILING("Parsing");
259*89c4ff92SAndroid Build Coastguard Worker             m_NetworkPtr = (m_IsModelBinary ?
260*89c4ff92SAndroid Build Coastguard Worker                 parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str(), inputShapes, m_OutputNames) :
261*89c4ff92SAndroid Build Coastguard Worker                 parser->CreateNetworkFromTextFile(m_ModelPath.c_str(), inputShapes, m_OutputNames));
262*89c4ff92SAndroid Build Coastguard Worker         }
263*89c4ff92SAndroid Build Coastguard Worker 
264*89c4ff92SAndroid Build Coastguard Worker         return m_NetworkPtr.get() != nullptr;
265*89c4ff92SAndroid Build Coastguard Worker     }
266*89c4ff92SAndroid Build Coastguard Worker 
267*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_TF_LITE_PARSER)
CreateNetwork(ParserType<armnnTfLiteParser::ITfLiteParser>)268*89c4ff92SAndroid Build Coastguard Worker     bool CreateNetwork (ParserType<armnnTfLiteParser::ITfLiteParser>)
269*89c4ff92SAndroid Build Coastguard Worker     {
270*89c4ff92SAndroid Build Coastguard Worker         // Create a network from a file on disk
271*89c4ff92SAndroid Build Coastguard Worker         auto parser(armnnTfLiteParser::ITfLiteParser::Create());
272*89c4ff92SAndroid Build Coastguard Worker 
273*89c4ff92SAndroid Build Coastguard Worker         if (!m_InputShapes.empty())
274*89c4ff92SAndroid Build Coastguard Worker         {
275*89c4ff92SAndroid Build Coastguard Worker             const size_t numInputShapes   = m_InputShapes.size();
276*89c4ff92SAndroid Build Coastguard Worker             const size_t numInputBindings = m_InputNames.size();
277*89c4ff92SAndroid Build Coastguard Worker             if (numInputShapes < numInputBindings)
278*89c4ff92SAndroid Build Coastguard Worker             {
279*89c4ff92SAndroid Build Coastguard Worker                 throw armnn::Exception(fmt::format(
280*89c4ff92SAndroid Build Coastguard Worker                    "Not every input has its tensor shape specified: expected={0}, got={1}",
281*89c4ff92SAndroid Build Coastguard Worker                    numInputBindings, numInputShapes));
282*89c4ff92SAndroid Build Coastguard Worker             }
283*89c4ff92SAndroid Build Coastguard Worker         }
284*89c4ff92SAndroid Build Coastguard Worker 
285*89c4ff92SAndroid Build Coastguard Worker         {
286*89c4ff92SAndroid Build Coastguard Worker             ARMNN_SCOPED_HEAP_PROFILING("Parsing");
287*89c4ff92SAndroid Build Coastguard Worker             m_NetworkPtr = parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str());
288*89c4ff92SAndroid Build Coastguard Worker         }
289*89c4ff92SAndroid Build Coastguard Worker 
290*89c4ff92SAndroid Build Coastguard Worker         return m_NetworkPtr.get() != nullptr;
291*89c4ff92SAndroid Build Coastguard Worker     }
292*89c4ff92SAndroid Build Coastguard Worker #endif
293*89c4ff92SAndroid Build Coastguard Worker 
294*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_ONNX_PARSER)
CreateNetwork(ParserType<armnnOnnxParser::IOnnxParser>)295*89c4ff92SAndroid Build Coastguard Worker     bool CreateNetwork (ParserType<armnnOnnxParser::IOnnxParser>)
296*89c4ff92SAndroid Build Coastguard Worker     {
297*89c4ff92SAndroid Build Coastguard Worker         // Create a network from a file on disk
298*89c4ff92SAndroid Build Coastguard Worker         auto parser(armnnOnnxParser::IOnnxParser::Create());
299*89c4ff92SAndroid Build Coastguard Worker 
300*89c4ff92SAndroid Build Coastguard Worker         if (!m_InputShapes.empty())
301*89c4ff92SAndroid Build Coastguard Worker         {
302*89c4ff92SAndroid Build Coastguard Worker             const size_t numInputShapes   = m_InputShapes.size();
303*89c4ff92SAndroid Build Coastguard Worker             const size_t numInputBindings = m_InputNames.size();
304*89c4ff92SAndroid Build Coastguard Worker             if (numInputShapes < numInputBindings)
305*89c4ff92SAndroid Build Coastguard Worker             {
306*89c4ff92SAndroid Build Coastguard Worker                 throw armnn::Exception(fmt::format(
307*89c4ff92SAndroid Build Coastguard Worker                    "Not every input has its tensor shape specified: expected={0}, got={1}",
308*89c4ff92SAndroid Build Coastguard Worker                    numInputBindings, numInputShapes));
309*89c4ff92SAndroid Build Coastguard Worker             }
310*89c4ff92SAndroid Build Coastguard Worker         }
311*89c4ff92SAndroid Build Coastguard Worker 
312*89c4ff92SAndroid Build Coastguard Worker         {
313*89c4ff92SAndroid Build Coastguard Worker             ARMNN_SCOPED_HEAP_PROFILING("Parsing");
314*89c4ff92SAndroid Build Coastguard Worker             m_NetworkPtr = (m_IsModelBinary ?
315*89c4ff92SAndroid Build Coastguard Worker                 parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str()) :
316*89c4ff92SAndroid Build Coastguard Worker                 parser->CreateNetworkFromTextFile(m_ModelPath.c_str()));
317*89c4ff92SAndroid Build Coastguard Worker         }
318*89c4ff92SAndroid Build Coastguard Worker 
319*89c4ff92SAndroid Build Coastguard Worker         return m_NetworkPtr.get() != nullptr;
320*89c4ff92SAndroid Build Coastguard Worker     }
321*89c4ff92SAndroid Build Coastguard Worker #endif
322*89c4ff92SAndroid Build Coastguard Worker 
323*89c4ff92SAndroid Build Coastguard Worker };
324*89c4ff92SAndroid Build Coastguard Worker 
325*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace
326*89c4ff92SAndroid Build Coastguard Worker 
main(int argc,char * argv[])327*89c4ff92SAndroid Build Coastguard Worker int main(int argc, char* argv[])
328*89c4ff92SAndroid Build Coastguard Worker {
329*89c4ff92SAndroid Build Coastguard Worker 
330*89c4ff92SAndroid Build Coastguard Worker #if (!defined(ARMNN_ONNX_PARSER) \
331*89c4ff92SAndroid Build Coastguard Worker        && !defined(ARMNN_TF_PARSER)   \
332*89c4ff92SAndroid Build Coastguard Worker        && !defined(ARMNN_TF_LITE_PARSER))
333*89c4ff92SAndroid Build Coastguard Worker     ARMNN_LOG(fatal) << "Not built with any of the supported parsers Onnx, Tensorflow, or TfLite.";
334*89c4ff92SAndroid Build Coastguard Worker     return EXIT_FAILURE;
335*89c4ff92SAndroid Build Coastguard Worker #endif
336*89c4ff92SAndroid Build Coastguard Worker 
337*89c4ff92SAndroid Build Coastguard Worker #if !defined(ARMNN_SERIALIZER)
338*89c4ff92SAndroid Build Coastguard Worker     ARMNN_LOG(fatal) << "Not built with Serializer support.";
339*89c4ff92SAndroid Build Coastguard Worker     return EXIT_FAILURE;
340*89c4ff92SAndroid Build Coastguard Worker #endif
341*89c4ff92SAndroid Build Coastguard Worker 
342*89c4ff92SAndroid Build Coastguard Worker #ifdef NDEBUG
343*89c4ff92SAndroid Build Coastguard Worker     armnn::LogSeverity level = armnn::LogSeverity::Info;
344*89c4ff92SAndroid Build Coastguard Worker #else
345*89c4ff92SAndroid Build Coastguard Worker     armnn::LogSeverity level = armnn::LogSeverity::Debug;
346*89c4ff92SAndroid Build Coastguard Worker #endif
347*89c4ff92SAndroid Build Coastguard Worker 
348*89c4ff92SAndroid Build Coastguard Worker     armnn::ConfigureLogging(true, true, level);
349*89c4ff92SAndroid Build Coastguard Worker 
350*89c4ff92SAndroid Build Coastguard Worker     std::string modelFormat;
351*89c4ff92SAndroid Build Coastguard Worker     std::string modelPath;
352*89c4ff92SAndroid Build Coastguard Worker 
353*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::string> inputNames;
354*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::string> inputTensorShapeStrs;
355*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::TensorShape> inputTensorShapes;
356*89c4ff92SAndroid Build Coastguard Worker 
357*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::string> outputNames;
358*89c4ff92SAndroid Build Coastguard Worker     std::string outputPath;
359*89c4ff92SAndroid Build Coastguard Worker 
360*89c4ff92SAndroid Build Coastguard Worker     bool isModelBinary = true;
361*89c4ff92SAndroid Build Coastguard Worker 
362*89c4ff92SAndroid Build Coastguard Worker     if (ParseCommandLineArgs(
363*89c4ff92SAndroid Build Coastguard Worker         argc, argv, modelFormat, modelPath, inputNames, inputTensorShapeStrs, outputNames, outputPath, isModelBinary)
364*89c4ff92SAndroid Build Coastguard Worker         != EXIT_SUCCESS)
365*89c4ff92SAndroid Build Coastguard Worker     {
366*89c4ff92SAndroid Build Coastguard Worker         return EXIT_FAILURE;
367*89c4ff92SAndroid Build Coastguard Worker     }
368*89c4ff92SAndroid Build Coastguard Worker 
369*89c4ff92SAndroid Build Coastguard Worker     for (const std::string& shapeStr : inputTensorShapeStrs)
370*89c4ff92SAndroid Build Coastguard Worker     {
371*89c4ff92SAndroid Build Coastguard Worker         if (!shapeStr.empty())
372*89c4ff92SAndroid Build Coastguard Worker         {
373*89c4ff92SAndroid Build Coastguard Worker             std::stringstream ss(shapeStr);
374*89c4ff92SAndroid Build Coastguard Worker 
375*89c4ff92SAndroid Build Coastguard Worker             try
376*89c4ff92SAndroid Build Coastguard Worker             {
377*89c4ff92SAndroid Build Coastguard Worker                 armnn::TensorShape shape = ParseTensorShape(ss);
378*89c4ff92SAndroid Build Coastguard Worker                 inputTensorShapes.push_back(shape);
379*89c4ff92SAndroid Build Coastguard Worker             }
380*89c4ff92SAndroid Build Coastguard Worker             catch (const armnn::InvalidArgumentException& e)
381*89c4ff92SAndroid Build Coastguard Worker             {
382*89c4ff92SAndroid Build Coastguard Worker                 ARMNN_LOG(fatal) << "Cannot create tensor shape: " << e.what();
383*89c4ff92SAndroid Build Coastguard Worker                 return EXIT_FAILURE;
384*89c4ff92SAndroid Build Coastguard Worker             }
385*89c4ff92SAndroid Build Coastguard Worker         }
386*89c4ff92SAndroid Build Coastguard Worker     }
387*89c4ff92SAndroid Build Coastguard Worker 
388*89c4ff92SAndroid Build Coastguard Worker     ArmnnConverter converter(modelPath, inputNames, inputTensorShapes, outputNames, outputPath, isModelBinary);
389*89c4ff92SAndroid Build Coastguard Worker 
390*89c4ff92SAndroid Build Coastguard Worker     try
391*89c4ff92SAndroid Build Coastguard Worker     {
392*89c4ff92SAndroid Build Coastguard Worker         if (modelFormat.find("onnx") != std::string::npos)
393*89c4ff92SAndroid Build Coastguard Worker         {
394*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_ONNX_PARSER)
395*89c4ff92SAndroid Build Coastguard Worker             if (!converter.CreateNetwork<armnnOnnxParser::IOnnxParser>())
396*89c4ff92SAndroid Build Coastguard Worker             {
397*89c4ff92SAndroid Build Coastguard Worker                 ARMNN_LOG(fatal) << "Failed to load model from file";
398*89c4ff92SAndroid Build Coastguard Worker                 return EXIT_FAILURE;
399*89c4ff92SAndroid Build Coastguard Worker             }
400*89c4ff92SAndroid Build Coastguard Worker #else
401*89c4ff92SAndroid Build Coastguard Worker             ARMNN_LOG(fatal) << "Not built with Onnx parser support.";
402*89c4ff92SAndroid Build Coastguard Worker             return EXIT_FAILURE;
403*89c4ff92SAndroid Build Coastguard Worker #endif
404*89c4ff92SAndroid Build Coastguard Worker         }
405*89c4ff92SAndroid Build Coastguard Worker         else if (modelFormat.find("tflite") != std::string::npos)
406*89c4ff92SAndroid Build Coastguard Worker         {
407*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_TF_LITE_PARSER)
408*89c4ff92SAndroid Build Coastguard Worker             if (!isModelBinary)
409*89c4ff92SAndroid Build Coastguard Worker             {
410*89c4ff92SAndroid Build Coastguard Worker                 ARMNN_LOG(fatal) << "Unknown model format: '" << modelFormat << "'. Only 'binary' format supported \
411*89c4ff92SAndroid Build Coastguard Worker                   for tflite files";
412*89c4ff92SAndroid Build Coastguard Worker                 return EXIT_FAILURE;
413*89c4ff92SAndroid Build Coastguard Worker             }
414*89c4ff92SAndroid Build Coastguard Worker 
415*89c4ff92SAndroid Build Coastguard Worker             if (!converter.CreateNetwork<armnnTfLiteParser::ITfLiteParser>())
416*89c4ff92SAndroid Build Coastguard Worker             {
417*89c4ff92SAndroid Build Coastguard Worker                 ARMNN_LOG(fatal) << "Failed to load model from file";
418*89c4ff92SAndroid Build Coastguard Worker                 return EXIT_FAILURE;
419*89c4ff92SAndroid Build Coastguard Worker             }
420*89c4ff92SAndroid Build Coastguard Worker #else
421*89c4ff92SAndroid Build Coastguard Worker             ARMNN_LOG(fatal) << "Not built with TfLite parser support.";
422*89c4ff92SAndroid Build Coastguard Worker             return EXIT_FAILURE;
423*89c4ff92SAndroid Build Coastguard Worker #endif
424*89c4ff92SAndroid Build Coastguard Worker         }
425*89c4ff92SAndroid Build Coastguard Worker         else
426*89c4ff92SAndroid Build Coastguard Worker         {
427*89c4ff92SAndroid Build Coastguard Worker             ARMNN_LOG(fatal) << "Unknown model format: '" << modelFormat << "'";
428*89c4ff92SAndroid Build Coastguard Worker             return EXIT_FAILURE;
429*89c4ff92SAndroid Build Coastguard Worker         }
430*89c4ff92SAndroid Build Coastguard Worker     }
431*89c4ff92SAndroid Build Coastguard Worker     catch(armnn::Exception& e)
432*89c4ff92SAndroid Build Coastguard Worker     {
433*89c4ff92SAndroid Build Coastguard Worker         ARMNN_LOG(fatal) << "Failed to load model from file: " << e.what();
434*89c4ff92SAndroid Build Coastguard Worker         return EXIT_FAILURE;
435*89c4ff92SAndroid Build Coastguard Worker     }
436*89c4ff92SAndroid Build Coastguard Worker 
437*89c4ff92SAndroid Build Coastguard Worker     if (!converter.Serialize())
438*89c4ff92SAndroid Build Coastguard Worker     {
439*89c4ff92SAndroid Build Coastguard Worker         ARMNN_LOG(fatal) << "Failed to serialize model";
440*89c4ff92SAndroid Build Coastguard Worker         return EXIT_FAILURE;
441*89c4ff92SAndroid Build Coastguard Worker     }
442*89c4ff92SAndroid Build Coastguard Worker 
443*89c4ff92SAndroid Build Coastguard Worker     return EXIT_SUCCESS;
444*89c4ff92SAndroid Build Coastguard Worker }
445