xref: /aosp_15_r20/external/armnn/src/armnnTfLiteParser/TfLiteParser.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #pragma once
6*89c4ff92SAndroid Build Coastguard Worker 
#include "armnnTfLiteParser/ITfLiteParser.hpp"

#include <armnn/Descriptors.hpp>
#include "armnn/INetwork.hpp"
#include "armnn/Types.hpp"

#include <schema_generated.h>
#include <tensorflow/lite/version.h>

#include <functional>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
18*89c4ff92SAndroid Build Coastguard Worker 
// ARMNN_POST_TFLITE_2_3 is defined when building against a TensorFlow Lite release newer
// than 2.3: any 2.x release with minor version 4 or above, or any major version 3 or later.
// TF_MAJOR_VERSION / TF_MINOR_VERSION are presumably provided via the
// tensorflow/lite/version.h include; if they are undefined the preprocessor treats them as 0
// and the macro is not defined.
#if (TF_MAJOR_VERSION >= 3) || ((TF_MAJOR_VERSION == 2) && (TF_MINOR_VERSION >= 4))
#define ARMNN_POST_TFLITE_2_3
#endif
22*89c4ff92SAndroid Build Coastguard Worker 
23*89c4ff92SAndroid Build Coastguard Worker namespace armnnTfLiteParser
24*89c4ff92SAndroid Build Coastguard Worker {
25*89c4ff92SAndroid Build Coastguard Worker 
26*89c4ff92SAndroid Build Coastguard Worker class TfLiteParserImpl
27*89c4ff92SAndroid Build Coastguard Worker {
28*89c4ff92SAndroid Build Coastguard Worker public:
29*89c4ff92SAndroid Build Coastguard Worker     // Shorthands for TfLite types
30*89c4ff92SAndroid Build Coastguard Worker     using ModelPtr = std::unique_ptr<tflite::ModelT>;
31*89c4ff92SAndroid Build Coastguard Worker     using SubgraphPtr = std::unique_ptr<tflite::SubGraphT>;
32*89c4ff92SAndroid Build Coastguard Worker     using OperatorPtr = std::unique_ptr<tflite::OperatorT>;
33*89c4ff92SAndroid Build Coastguard Worker     using OperatorCodePtr = std::unique_ptr<tflite::OperatorCodeT>;
34*89c4ff92SAndroid Build Coastguard Worker     using TensorPtr = std::unique_ptr<tflite::TensorT>;
35*89c4ff92SAndroid Build Coastguard Worker     using TensorRawPtr = const tflite::TensorT *;
36*89c4ff92SAndroid Build Coastguard Worker     using TensorRawPtrVector = std::vector<TensorRawPtr>;
37*89c4ff92SAndroid Build Coastguard Worker     using TensorIdRawPtr = std::pair<size_t, TensorRawPtr>;
38*89c4ff92SAndroid Build Coastguard Worker     using TensorIdRawPtrVector = std::vector<TensorIdRawPtr>;
39*89c4ff92SAndroid Build Coastguard Worker     using BufferPtr = std::unique_ptr<tflite::BufferT>;
40*89c4ff92SAndroid Build Coastguard Worker     using BufferRawPtr = const tflite::BufferT *;
41*89c4ff92SAndroid Build Coastguard Worker 
42*89c4ff92SAndroid Build Coastguard Worker public:
43*89c4ff92SAndroid Build Coastguard Worker     /// Create the network from a flatbuffers binary file on disk
44*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr CreateNetworkFromBinaryFile(const char* graphFile);
45*89c4ff92SAndroid Build Coastguard Worker 
46*89c4ff92SAndroid Build Coastguard Worker     /// Create the network from a flatbuffers binary
47*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t> & binaryContent);
48*89c4ff92SAndroid Build Coastguard Worker 
49*89c4ff92SAndroid Build Coastguard Worker 
50*89c4ff92SAndroid Build Coastguard Worker     /// Retrieve binding info (layer id and tensor info) for the network input identified by
51*89c4ff92SAndroid Build Coastguard Worker     /// the given layer name and subgraph id
52*89c4ff92SAndroid Build Coastguard Worker     BindingPointInfo GetNetworkInputBindingInfo(size_t subgraphId,
53*89c4ff92SAndroid Build Coastguard Worker                                                 const std::string& name) const;
54*89c4ff92SAndroid Build Coastguard Worker 
55*89c4ff92SAndroid Build Coastguard Worker     /// Retrieve binding info (layer id and tensor info) for the network output identified by
56*89c4ff92SAndroid Build Coastguard Worker     /// the given layer name and subgraph id
57*89c4ff92SAndroid Build Coastguard Worker     BindingPointInfo GetNetworkOutputBindingInfo(size_t subgraphId,
58*89c4ff92SAndroid Build Coastguard Worker                                                          const std::string& name) const;
59*89c4ff92SAndroid Build Coastguard Worker 
60*89c4ff92SAndroid Build Coastguard Worker     /// Return the number of subgraphs in the parsed model
61*89c4ff92SAndroid Build Coastguard Worker     size_t GetSubgraphCount() const;
62*89c4ff92SAndroid Build Coastguard Worker 
63*89c4ff92SAndroid Build Coastguard Worker     /// Return the input tensor names for a given subgraph
64*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::string> GetSubgraphInputTensorNames(size_t subgraphId) const;
65*89c4ff92SAndroid Build Coastguard Worker 
66*89c4ff92SAndroid Build Coastguard Worker     /// Return the output tensor names for a given subgraph
67*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::string> GetSubgraphOutputTensorNames(size_t subgraphId) const;
68*89c4ff92SAndroid Build Coastguard Worker 
69*89c4ff92SAndroid Build Coastguard Worker     TfLiteParserImpl(const armnn::Optional<ITfLiteParser::TfLiteParserOptions>& options = armnn::EmptyOptional());
70*89c4ff92SAndroid Build Coastguard Worker     ~TfLiteParserImpl() = default;
71*89c4ff92SAndroid Build Coastguard Worker 
72*89c4ff92SAndroid Build Coastguard Worker public:
73*89c4ff92SAndroid Build Coastguard Worker     // testable helpers
74*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr CreateNetworkFromBinaryAsDynamic(const std::vector<uint8_t>& binaryContent);
75*89c4ff92SAndroid Build Coastguard Worker 
76*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr LoadModel(std::unique_ptr<tflite::ModelT> model);
77*89c4ff92SAndroid Build Coastguard Worker 
78*89c4ff92SAndroid Build Coastguard Worker     static ModelPtr LoadModelFromFile(const char* fileName);
79*89c4ff92SAndroid Build Coastguard Worker     static ModelPtr LoadModelFromBinary(const uint8_t* binaryContent, size_t len);
80*89c4ff92SAndroid Build Coastguard Worker     static TensorRawPtrVector GetInputs(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex);
81*89c4ff92SAndroid Build Coastguard Worker     static TensorRawPtrVector GetOutputs(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex);
82*89c4ff92SAndroid Build Coastguard Worker     static TensorIdRawPtrVector GetSubgraphInputs(const ModelPtr& model, size_t subgraphIndex);
83*89c4ff92SAndroid Build Coastguard Worker     static TensorIdRawPtrVector GetSubgraphOutputs(const ModelPtr& model, size_t subgraphIndex);
84*89c4ff92SAndroid Build Coastguard Worker     static std::vector<int32_t>& GetInputTensorIds(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex);
85*89c4ff92SAndroid Build Coastguard Worker     static std::vector<int32_t>& GetOutputTensorIds(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex);
86*89c4ff92SAndroid Build Coastguard Worker 
87*89c4ff92SAndroid Build Coastguard Worker     static BufferRawPtr GetBuffer(const ModelPtr& model, size_t bufferIndex);
88*89c4ff92SAndroid Build Coastguard Worker     static armnn::TensorInfo OutputShapeOfSqueeze(std::vector<uint32_t> squeezeDims,
89*89c4ff92SAndroid Build Coastguard Worker                                                   const armnn::TensorInfo& inputTensorInfo);
90*89c4ff92SAndroid Build Coastguard Worker     static armnn::TensorInfo OutputShapeOfReshape(const armnn::TensorInfo& inputTensorInfo,
91*89c4ff92SAndroid Build Coastguard Worker                                                   const std::vector<int32_t>& targetDimsIn);
92*89c4ff92SAndroid Build Coastguard Worker 
93*89c4ff92SAndroid Build Coastguard Worker     /// Retrieve version in X.Y.Z form
94*89c4ff92SAndroid Build Coastguard Worker     static const std::string GetVersion();
95*89c4ff92SAndroid Build Coastguard Worker 
96*89c4ff92SAndroid Build Coastguard Worker private:
97*89c4ff92SAndroid Build Coastguard Worker 
98*89c4ff92SAndroid Build Coastguard Worker     // No copying allowed until it is wanted and properly implemented
99*89c4ff92SAndroid Build Coastguard Worker     TfLiteParserImpl(const TfLiteParserImpl &) = delete;
100*89c4ff92SAndroid Build Coastguard Worker     TfLiteParserImpl & operator=(const TfLiteParserImpl &) = delete;
101*89c4ff92SAndroid Build Coastguard Worker 
102*89c4ff92SAndroid Build Coastguard Worker     /// Create the network from an already loaded flatbuffers model
103*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr CreateNetworkFromModel();
104*89c4ff92SAndroid Build Coastguard Worker 
105*89c4ff92SAndroid Build Coastguard Worker     // signature for the parser functions
106*89c4ff92SAndroid Build Coastguard Worker     using OperatorParsingFunction = void(TfLiteParserImpl::*)(size_t subgraphIndex, size_t operatorIndex);
107*89c4ff92SAndroid Build Coastguard Worker 
108*89c4ff92SAndroid Build Coastguard Worker     void ParseCustomOperator(size_t subgraphIndex, size_t operatorIndex);
109*89c4ff92SAndroid Build Coastguard Worker     void ParseUnsupportedOperator(size_t subgraphIndex, size_t operatorIndex);
110*89c4ff92SAndroid Build Coastguard Worker 
111*89c4ff92SAndroid Build Coastguard Worker     void ParseAbs(size_t subgraphIndex, size_t operatorIndex);
112*89c4ff92SAndroid Build Coastguard Worker     void ParseActivation(size_t subgraphIndex, size_t operatorIndex, armnn::ActivationFunction activationType);
113*89c4ff92SAndroid Build Coastguard Worker     void ParseAdd(size_t subgraphIndex, size_t operatorIndex);
114*89c4ff92SAndroid Build Coastguard Worker     void ParseArgMinMax(size_t subgraphIndex, size_t operatorIndex, armnn::ArgMinMaxFunction argMinMaxFunction);
115*89c4ff92SAndroid Build Coastguard Worker     void ParseArgMin(size_t subgraphIndex, size_t operatorIndex);
116*89c4ff92SAndroid Build Coastguard Worker     void ParseArgMax(size_t subgraphIndex, size_t operatorIndex);
117*89c4ff92SAndroid Build Coastguard Worker     void ParseAveragePool2D(size_t subgraphIndex, size_t operatorIndex);
118*89c4ff92SAndroid Build Coastguard Worker     void ParseBatchMatMul(size_t subgraphIndex, size_t operatorIndex);
119*89c4ff92SAndroid Build Coastguard Worker     void ParseBatchToSpaceND(size_t subgraphIndex, size_t operatorIndex);
120*89c4ff92SAndroid Build Coastguard Worker     void ParseCast(size_t subgraphIndex, size_t operatorIndex);
121*89c4ff92SAndroid Build Coastguard Worker     void ParseCeil(size_t subgraphIndex, size_t operatorIndex);
122*89c4ff92SAndroid Build Coastguard Worker     void ParseComparison(size_t subgraphIndex, size_t operatorIndex, armnn::ComparisonOperation comparisonOperation);
123*89c4ff92SAndroid Build Coastguard Worker     void ParseConcatenation(size_t subgraphIndex, size_t operatorIndex);
124*89c4ff92SAndroid Build Coastguard Worker     void ParseConv2D(size_t subgraphIndex, size_t operatorIndex);
125*89c4ff92SAndroid Build Coastguard Worker     // Conv3D support was added in TF 2.5, so for backwards compatibility a hash define is needed.
126*89c4ff92SAndroid Build Coastguard Worker     #if defined(ARMNN_POST_TFLITE_2_4)
127*89c4ff92SAndroid Build Coastguard Worker     void ParseConv3D(size_t subgraphIndex, size_t operatorIndex);
128*89c4ff92SAndroid Build Coastguard Worker     #endif
129*89c4ff92SAndroid Build Coastguard Worker     void ParseDepthToSpace(size_t subgraphIndex, size_t operatorIndex);
130*89c4ff92SAndroid Build Coastguard Worker     void ParseDepthwiseConv2D(size_t subgraphIndex, size_t operatorIndex);
131*89c4ff92SAndroid Build Coastguard Worker     void ParseDequantize(size_t subgraphIndex, size_t operatorIndex);
132*89c4ff92SAndroid Build Coastguard Worker     void ParseDetectionPostProcess(size_t subgraphIndex, size_t operatorIndex);
133*89c4ff92SAndroid Build Coastguard Worker     void ParseDiv(size_t subgraphIndex, size_t operatorIndex);
134*89c4ff92SAndroid Build Coastguard Worker     void ParseElementwiseUnary(size_t subgraphIndex, size_t operatorIndex, armnn::UnaryOperation unaryOperation);
135*89c4ff92SAndroid Build Coastguard Worker     void ParseElu(size_t subgraphIndex, size_t operatorIndex);
136*89c4ff92SAndroid Build Coastguard Worker     void ParseEqual(size_t subgraphIndex, size_t operatorIndex);
137*89c4ff92SAndroid Build Coastguard Worker     void ParseExp(size_t subgraphIndex, size_t operatorIndex);
138*89c4ff92SAndroid Build Coastguard Worker     void ParseExpandDims(size_t subgraphIndex, size_t operatorIndex);
139*89c4ff92SAndroid Build Coastguard Worker     void ParseFloorDiv(size_t subgraphIndex, size_t operatorIndex);
140*89c4ff92SAndroid Build Coastguard Worker     void ParseFullyConnected(size_t subgraphIndex, size_t operatorIndex);
141*89c4ff92SAndroid Build Coastguard Worker     void ParseGather(size_t subgraphIndex, size_t operatorIndex);
142*89c4ff92SAndroid Build Coastguard Worker     void ParseGatherNd(size_t subgraphIndex, size_t operatorIndex);
143*89c4ff92SAndroid Build Coastguard Worker     void ParseGreater(size_t subgraphIndex, size_t operatorIndex);
144*89c4ff92SAndroid Build Coastguard Worker     void ParseGreaterOrEqual(size_t subgraphIndex, size_t operatorIndex);
145*89c4ff92SAndroid Build Coastguard Worker     void ParseHardSwish(size_t subgraphIndex, size_t operatorIndex);
146*89c4ff92SAndroid Build Coastguard Worker     void ParseLeakyRelu(size_t subgraphIndex, size_t operatorIndex);
147*89c4ff92SAndroid Build Coastguard Worker     void ParseLess(size_t subgraphIndex, size_t operatorIndex);
148*89c4ff92SAndroid Build Coastguard Worker     void ParseLessOrEqual(size_t subgraphIndex, size_t operatorIndex);
149*89c4ff92SAndroid Build Coastguard Worker     void ParseLog(size_t subgraphIndex, size_t operatorIndex);
150*89c4ff92SAndroid Build Coastguard Worker     void ParseLocalResponseNormalization(size_t subgraphIndex, size_t operatorIndex);
151*89c4ff92SAndroid Build Coastguard Worker     void ParseLogicalNot(size_t subgraphIndex, size_t operatorIndex);
152*89c4ff92SAndroid Build Coastguard Worker     void ParseLogistic(size_t subgraphIndex, size_t operatorIndex);
153*89c4ff92SAndroid Build Coastguard Worker     void ParseLogSoftmax(size_t subgraphIndex, size_t operatorIndex);
154*89c4ff92SAndroid Build Coastguard Worker     void ParseL2Normalization(size_t subgraphIndex, size_t operatorIndex);
155*89c4ff92SAndroid Build Coastguard Worker     void ParseMaxPool2D(size_t subgraphIndex, size_t operatorIndex);
156*89c4ff92SAndroid Build Coastguard Worker     void ParseMaximum(size_t subgraphIndex, size_t operatorIndex);
157*89c4ff92SAndroid Build Coastguard Worker     void ParseMean(size_t subgraphIndex, size_t operatorIndex);
158*89c4ff92SAndroid Build Coastguard Worker     void ParseMinimum(size_t subgraphIndex, size_t operatorIndex);
159*89c4ff92SAndroid Build Coastguard Worker     void ParseMirrorPad(size_t subgraphIndex, size_t operatorIndex);
160*89c4ff92SAndroid Build Coastguard Worker     void ParseMul(size_t subgraphIndex, size_t operatorIndex);
161*89c4ff92SAndroid Build Coastguard Worker     void ParseNeg(size_t subgraphIndex, size_t operatorIndex);
162*89c4ff92SAndroid Build Coastguard Worker     void ParseNotEqual(size_t subgraphIndex, size_t operatorIndex);
163*89c4ff92SAndroid Build Coastguard Worker     void ParsePack(size_t subgraphIndex, size_t operatorIndex);
164*89c4ff92SAndroid Build Coastguard Worker     void ParsePad(size_t subgraphIndex, size_t operatorIndex);
165*89c4ff92SAndroid Build Coastguard Worker     void ParsePool(size_t subgraphIndex, size_t operatorIndex, armnn::PoolingAlgorithm algorithm);
166*89c4ff92SAndroid Build Coastguard Worker     void ParsePrelu(size_t subgraphIndex, size_t operatorIndex);
167*89c4ff92SAndroid Build Coastguard Worker     void ParseQuantize(size_t subgraphIndex, size_t operatorIndex);
168*89c4ff92SAndroid Build Coastguard Worker     void ParseReduce(size_t subgraphIndex, size_t operatorIndex, armnn::ReduceOperation reduceOperation);
169*89c4ff92SAndroid Build Coastguard Worker     void ParseReduceMax(size_t subgraphIndex, size_t operatorIndex);
170*89c4ff92SAndroid Build Coastguard Worker     void ParseReduceMin(size_t subgraphIndex, size_t operatorIndex);
171*89c4ff92SAndroid Build Coastguard Worker     void ParseReduceProd(size_t subgraphIndex, size_t operatorIndex);
172*89c4ff92SAndroid Build Coastguard Worker     void ParseRelu(size_t subgraphIndex, size_t operatorIndex);
173*89c4ff92SAndroid Build Coastguard Worker     void ParseRelu6(size_t subgraphIndex, size_t operatorIndex);
174*89c4ff92SAndroid Build Coastguard Worker     void ParseReshape(size_t subgraphIndex, size_t operatorIndex);
175*89c4ff92SAndroid Build Coastguard Worker     void ParseResize(size_t subgraphIndex, size_t operatorIndex, armnn::ResizeMethod resizeMethod);
176*89c4ff92SAndroid Build Coastguard Worker     void ParseResizeBilinear(size_t subgraphIndex, size_t operatorIndex);
177*89c4ff92SAndroid Build Coastguard Worker     void ParseResizeNearestNeighbor(size_t subgraphIndex, size_t operatorIndex);
178*89c4ff92SAndroid Build Coastguard Worker     void ParseRsqrt(size_t subgraphIndex, size_t operatorIndex);
179*89c4ff92SAndroid Build Coastguard Worker     void ParseShape(size_t subgraphIndex, size_t operatorIndex);
180*89c4ff92SAndroid Build Coastguard Worker     void ParseSin(size_t subgraphIndex, size_t operatorIndex);
181*89c4ff92SAndroid Build Coastguard Worker     void ParseSlice(size_t subgraphIndex, size_t operatorIndex);
182*89c4ff92SAndroid Build Coastguard Worker     void ParseSoftmax(size_t subgraphIndex, size_t operatorIndex);
183*89c4ff92SAndroid Build Coastguard Worker     void ParseSqrt(size_t subgraphIndex, size_t operatorIndex);
184*89c4ff92SAndroid Build Coastguard Worker     void ParseSpaceToBatchND(size_t subgraphIndex, size_t operatorIndex);
185*89c4ff92SAndroid Build Coastguard Worker     void ParseSpaceToDepth(size_t subgraphIndex, size_t operatorIndex);
186*89c4ff92SAndroid Build Coastguard Worker     void ParseSplit(size_t subgraphIndex, size_t operatorIndex);
187*89c4ff92SAndroid Build Coastguard Worker     void ParseSplitV(size_t subgraphIndex, size_t operatorIndex);
188*89c4ff92SAndroid Build Coastguard Worker     void ParseSqueeze(size_t subgraphIndex, size_t operatorIndex);
189*89c4ff92SAndroid Build Coastguard Worker     void ParseStridedSlice(size_t subgraphIndex, size_t operatorIndex);
190*89c4ff92SAndroid Build Coastguard Worker     void ParseSub(size_t subgraphIndex, size_t operatorIndex);
191*89c4ff92SAndroid Build Coastguard Worker     void ParseSum(size_t subgraphIndex, size_t operatorIndex);
192*89c4ff92SAndroid Build Coastguard Worker     void ParseTanH(size_t subgraphIndex, size_t operatorIndex);
193*89c4ff92SAndroid Build Coastguard Worker     void ParseTranspose(size_t subgraphIndex, size_t operatorIndex);
194*89c4ff92SAndroid Build Coastguard Worker     void ParseTransposeConv(size_t subgraphIndex, size_t operatorIndex);
195*89c4ff92SAndroid Build Coastguard Worker     void ParseUnidirectionalSequenceLSTM(size_t subgraphIndex, size_t operatorIndex);
196*89c4ff92SAndroid Build Coastguard Worker     void ParseUnpack(size_t subgraphIndex, size_t operatorIndex);
197*89c4ff92SAndroid Build Coastguard Worker 
198*89c4ff92SAndroid Build Coastguard Worker     void RegisterProducerOfTensor(size_t subgraphIndex, size_t tensorIndex, armnn::IOutputSlot* slot);
199*89c4ff92SAndroid Build Coastguard Worker     void RegisterConsumerOfTensor(size_t subgraphIndex, size_t tensorIndex, armnn::IInputSlot* slot);
200*89c4ff92SAndroid Build Coastguard Worker     void RegisterInputSlots(size_t subgraphIndex,
201*89c4ff92SAndroid Build Coastguard Worker                             size_t operatorIndex,
202*89c4ff92SAndroid Build Coastguard Worker                             armnn::IConnectableLayer* layer,
203*89c4ff92SAndroid Build Coastguard Worker                             const std::vector<unsigned int>& tensorIndexes,
204*89c4ff92SAndroid Build Coastguard Worker                             unsigned int startingSlotIndex = 0);
205*89c4ff92SAndroid Build Coastguard Worker     void RegisterOutputSlots(size_t subgraphIndex,
206*89c4ff92SAndroid Build Coastguard Worker                              size_t operatorIndex,
207*89c4ff92SAndroid Build Coastguard Worker                              armnn::IConnectableLayer* layer,
208*89c4ff92SAndroid Build Coastguard Worker                              const std::vector<unsigned int>& tensorIndexes);
209*89c4ff92SAndroid Build Coastguard Worker 
210*89c4ff92SAndroid Build Coastguard Worker     void SetupInputLayerTensorInfos(size_t subgraphIndex);
211*89c4ff92SAndroid Build Coastguard Worker     void SetupConstantLayerTensorInfos(size_t subgraphIndex);
212*89c4ff92SAndroid Build Coastguard Worker 
213*89c4ff92SAndroid Build Coastguard Worker     void SetupInputLayers(size_t subgraphIndex);
214*89c4ff92SAndroid Build Coastguard Worker     void SetupOutputLayers(size_t subgraphIndex);
215*89c4ff92SAndroid Build Coastguard Worker     void SetupConstantLayers(size_t subgraphIndex);
216*89c4ff92SAndroid Build Coastguard Worker 
217*89c4ff92SAndroid Build Coastguard Worker     void ResetParser();
218*89c4ff92SAndroid Build Coastguard Worker 
219*89c4ff92SAndroid Build Coastguard Worker     void AddBroadcastReshapeLayer(size_t subgraphIndex,
220*89c4ff92SAndroid Build Coastguard Worker                                   size_t operatorIndex,
221*89c4ff92SAndroid Build Coastguard Worker                                   armnn::IConnectableLayer* layer);
222*89c4ff92SAndroid Build Coastguard Worker 
223*89c4ff92SAndroid Build Coastguard Worker     /// Attach an reshape layer to the one passed as a parameter
224*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* AddReshapeLayer(armnn::IConnectableLayer* layer,
225*89c4ff92SAndroid Build Coastguard Worker                                               unsigned int outputSlot,
226*89c4ff92SAndroid Build Coastguard Worker                                               std::string reshapeLayerName,
227*89c4ff92SAndroid Build Coastguard Worker                                               armnn::TensorInfo outputShape);
228*89c4ff92SAndroid Build Coastguard Worker 
229*89c4ff92SAndroid Build Coastguard Worker     /// Attach an activation layer to the one passed as a parameter
230*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* AddFusedActivationLayer(armnn::IConnectableLayer* layer,
231*89c4ff92SAndroid Build Coastguard Worker                                                       unsigned int outputSlot,
232*89c4ff92SAndroid Build Coastguard Worker                                                       tflite::ActivationFunctionType activationType);
233*89c4ff92SAndroid Build Coastguard Worker 
234*89c4ff92SAndroid Build Coastguard Worker     /// Attach a floor layer to the one passed as a parameter
235*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* AddFusedFloorLayer(armnn::IConnectableLayer* layer, unsigned int outputSlot);
236*89c4ff92SAndroid Build Coastguard Worker 
237*89c4ff92SAndroid Build Coastguard Worker     // SupportedDataStorage's purpose is to hold data till we pass over to the network.
238*89c4ff92SAndroid Build Coastguard Worker     // We don't care about the content, and we want a single datatype to simplify the code.
239*89c4ff92SAndroid Build Coastguard Worker     struct SupportedDataStorage
240*89c4ff92SAndroid Build Coastguard Worker     {
241*89c4ff92SAndroid Build Coastguard Worker     public:
242*89c4ff92SAndroid Build Coastguard Worker         // Convenience constructors
243*89c4ff92SAndroid Build Coastguard Worker         SupportedDataStorage(std::unique_ptr<float[]>&&   data);
244*89c4ff92SAndroid Build Coastguard Worker         SupportedDataStorage(std::unique_ptr<uint8_t[]>&& data);
245*89c4ff92SAndroid Build Coastguard Worker         SupportedDataStorage(std::unique_ptr<int8_t[]>&&  data);
246*89c4ff92SAndroid Build Coastguard Worker         SupportedDataStorage(std::unique_ptr<int32_t[]>&& data);
247*89c4ff92SAndroid Build Coastguard Worker 
248*89c4ff92SAndroid Build Coastguard Worker     private:
249*89c4ff92SAndroid Build Coastguard Worker         // Pointers to the data buffers
250*89c4ff92SAndroid Build Coastguard Worker         std::unique_ptr<float[]>   m_FloatData;
251*89c4ff92SAndroid Build Coastguard Worker         std::unique_ptr<uint8_t[]> m_Uint8Data;
252*89c4ff92SAndroid Build Coastguard Worker         std::unique_ptr<int8_t[]>  m_Int8Data;
253*89c4ff92SAndroid Build Coastguard Worker         std::unique_ptr<int32_t[]> m_Int32Data;
254*89c4ff92SAndroid Build Coastguard Worker     };
255*89c4ff92SAndroid Build Coastguard Worker 
256*89c4ff92SAndroid Build Coastguard Worker     bool ShouldConstantTensorBeCreated(unsigned int tensorIndex);
257*89c4ff92SAndroid Build Coastguard Worker 
258*89c4ff92SAndroid Build Coastguard Worker     bool IsConstTensor(TensorRawPtr tensorPtr);
259*89c4ff92SAndroid Build Coastguard Worker 
260*89c4ff92SAndroid Build Coastguard Worker     bool ShouldConstantTensorBeConverted(TfLiteParserImpl::TensorRawPtr tensorPtr,
261*89c4ff92SAndroid Build Coastguard Worker                                          armnn::DataType inputDataType,
262*89c4ff92SAndroid Build Coastguard Worker                                          armnn::DataType filterDataType);
263*89c4ff92SAndroid Build Coastguard Worker 
264*89c4ff92SAndroid Build Coastguard Worker     armnn::ConstTensor CreateConstTensorNonPermuted(TensorRawPtr tensorPtr,
265*89c4ff92SAndroid Build Coastguard Worker                                                     armnn::TensorInfo& tensorInfo);
266*89c4ff92SAndroid Build Coastguard Worker 
267*89c4ff92SAndroid Build Coastguard Worker     std::pair<armnn::ConstTensor, SupportedDataStorage>
268*89c4ff92SAndroid Build Coastguard Worker     CreateConstTensorPermuted(TensorRawPtr tensorPtr,
269*89c4ff92SAndroid Build Coastguard Worker                               armnn::TensorInfo& tensorInfo,
270*89c4ff92SAndroid Build Coastguard Worker                               armnn::Optional<armnn::PermutationVector&> permutationVector);
271*89c4ff92SAndroid Build Coastguard Worker 
272*89c4ff92SAndroid Build Coastguard Worker     std::pair<armnn::ConstTensor, std::unique_ptr<float[]>>
273*89c4ff92SAndroid Build Coastguard Worker     CreateConstTensorNonPermuted(TensorRawPtr tensorPtr,
274*89c4ff92SAndroid Build Coastguard Worker                                  armnn::TensorInfo& tensorInfo,
275*89c4ff92SAndroid Build Coastguard Worker                                  armnn::DataType inputDataType);
276*89c4ff92SAndroid Build Coastguard Worker 
277*89c4ff92SAndroid Build Coastguard Worker     template<typename T>
278*89c4ff92SAndroid Build Coastguard Worker     std::pair<armnn::ConstTensor, TfLiteParserImpl::SupportedDataStorage>
279*89c4ff92SAndroid Build Coastguard Worker     CreateConstTensorAndStoreData(TfLiteParserImpl::BufferRawPtr bufferPtr,
280*89c4ff92SAndroid Build Coastguard Worker                                   TfLiteParserImpl::TensorRawPtr tensorPtr,
281*89c4ff92SAndroid Build Coastguard Worker                                   armnn::TensorInfo& tensorInfo,
282*89c4ff92SAndroid Build Coastguard Worker                                   armnn::Optional<armnn::PermutationVector&> permutationVector);
283*89c4ff92SAndroid Build Coastguard Worker 
284*89c4ff92SAndroid Build Coastguard Worker     std::pair<armnn::ConstTensor*, std::unique_ptr<float[]>>
285*89c4ff92SAndroid Build Coastguard Worker     CreateConstTensorPtr(TensorRawPtr tensorPtr,
286*89c4ff92SAndroid Build Coastguard Worker                          armnn::TensorInfo& inputTensorInfo);
287*89c4ff92SAndroid Build Coastguard Worker 
288*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo InputTensorInfo(size_t subgraphIndex,
289*89c4ff92SAndroid Build Coastguard Worker                                       size_t operatorIndex,
290*89c4ff92SAndroid Build Coastguard Worker                                       int input);
291*89c4ff92SAndroid Build Coastguard Worker 
292*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo OutputTensorInfoFromInputs(size_t subgraphIndex,
293*89c4ff92SAndroid Build Coastguard Worker                                        size_t operatorIndex,
294*89c4ff92SAndroid Build Coastguard Worker                                        armnn::IConnectableLayer* layer,
295*89c4ff92SAndroid Build Coastguard Worker                                        int output,
296*89c4ff92SAndroid Build Coastguard Worker                                        std::vector<int> inputs);
297*89c4ff92SAndroid Build Coastguard Worker 
298*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo OutputTensorInfoFromShapes(size_t subgraphIndex,
299*89c4ff92SAndroid Build Coastguard Worker                                        size_t operatorIndex,
300*89c4ff92SAndroid Build Coastguard Worker                                        armnn::IConnectableLayer* layer,
301*89c4ff92SAndroid Build Coastguard Worker                                        int output = 0,
302*89c4ff92SAndroid Build Coastguard Worker                                        std::vector<armnn::TensorShape> inputShapes = {});
303*89c4ff92SAndroid Build Coastguard Worker 
304*89c4ff92SAndroid Build Coastguard Worker     /// Settings for configuring the TfLiteParser
305*89c4ff92SAndroid Build Coastguard Worker     armnn::Optional<ITfLiteParser::TfLiteParserOptions> m_Options;
306*89c4ff92SAndroid Build Coastguard Worker 
307*89c4ff92SAndroid Build Coastguard Worker     /// The network we're building. Gets cleared after it is passed to the user
308*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr                    m_Network;
309*89c4ff92SAndroid Build Coastguard Worker     ModelPtr                              m_Model;
310*89c4ff92SAndroid Build Coastguard Worker 
311*89c4ff92SAndroid Build Coastguard Worker     std::vector<OperatorParsingFunction>                     m_ParserFunctions;
312*89c4ff92SAndroid Build Coastguard Worker     std::unordered_map<std::string, OperatorParsingFunction> m_CustomParserFunctions;
313*89c4ff92SAndroid Build Coastguard Worker 
314*89c4ff92SAndroid Build Coastguard Worker     /// A mapping of an output slot to each of the input slots it should be connected to
315*89c4ff92SAndroid Build Coastguard Worker     /// The outputSlot is from the layer that creates this tensor as one of its ouputs
316*89c4ff92SAndroid Build Coastguard Worker     /// The inputSlots are from the layers that use this tensor as one of their inputs
317*89c4ff92SAndroid Build Coastguard Worker     struct TensorSlots
318*89c4ff92SAndroid Build Coastguard Worker     {
319*89c4ff92SAndroid Build Coastguard Worker         armnn::IOutputSlot* outputSlot;
320*89c4ff92SAndroid Build Coastguard Worker         std::vector<armnn::IInputSlot*> inputSlots;
321*89c4ff92SAndroid Build Coastguard Worker 
TensorSlotsarmnnTfLiteParser::TfLiteParserImpl::TensorSlots322*89c4ff92SAndroid Build Coastguard Worker         TensorSlots() : outputSlot(nullptr) { }
323*89c4ff92SAndroid Build Coastguard Worker     };
324*89c4ff92SAndroid Build Coastguard Worker     typedef std::vector<TensorSlots> TensorConnections;
325*89c4ff92SAndroid Build Coastguard Worker     /// Connections for tensors in each subgraph
326*89c4ff92SAndroid Build Coastguard Worker     /// The first index is the subgraph ID, the second index is the tensor ID
327*89c4ff92SAndroid Build Coastguard Worker     std::vector<TensorConnections> m_SubgraphConnections;
328*89c4ff92SAndroid Build Coastguard Worker 
329*89c4ff92SAndroid Build Coastguard Worker     /// This is used in case that the model does not specify the output.
330*89c4ff92SAndroid Build Coastguard Worker     /// The shape can be calculated from the options.
331*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::vector<unsigned int>> m_OverriddenOutputShapes;
332*89c4ff92SAndroid Build Coastguard Worker 
333*89c4ff92SAndroid Build Coastguard Worker     std::vector<unsigned int> m_ConstantsToDequantize;
334*89c4ff92SAndroid Build Coastguard Worker     std::vector<unsigned int> m_ConstantsToBeCreated;
335*89c4ff92SAndroid Build Coastguard Worker     std::map<size_t, armnn::TensorInfo> m_TensorInfos;
336*89c4ff92SAndroid Build Coastguard Worker };
337*89c4ff92SAndroid Build Coastguard Worker 
338*89c4ff92SAndroid Build Coastguard Worker }
339