xref: /aosp_15_r20/external/armnn/src/armnnTfLiteParser/TfLiteParser.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include <armnn/Descriptors.hpp>
#include "armnn/INetwork.hpp"
#include "armnnTfLiteParser/ITfLiteParser.hpp"
#include "armnn/Types.hpp"

#include <schema_generated.h>
#include <functional>
#include <unordered_map>
#include <vector>

#include <tensorflow/lite/version.h>

#if TF_MAJOR_VERSION > 2 || (TF_MAJOR_VERSION == 2 && TF_MINOR_VERSION > 3)
#define ARMNN_POST_TFLITE_2_3
#endif
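// For example, TensorFlow Lite 2.4 and later satisfy this check and define ARMNN_POST_TFLITE_2_3,
// while 2.3.x and earlier do not.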

namespace armnnTfLiteParser
{

class TfLiteParserImpl
{
public:
    // Shorthands for TfLite types
    using ModelPtr = std::unique_ptr<tflite::ModelT>;
    using SubgraphPtr = std::unique_ptr<tflite::SubGraphT>;
    using OperatorPtr = std::unique_ptr<tflite::OperatorT>;
    using OperatorCodePtr = std::unique_ptr<tflite::OperatorCodeT>;
    using TensorPtr = std::unique_ptr<tflite::TensorT>;
    using TensorRawPtr = const tflite::TensorT *;
    using TensorRawPtrVector = std::vector<TensorRawPtr>;
    using TensorIdRawPtr = std::pair<size_t, TensorRawPtr>;
    using TensorIdRawPtrVector = std::vector<TensorIdRawPtr>;
    using BufferPtr = std::unique_ptr<tflite::BufferT>;
    using BufferRawPtr = const tflite::BufferT *;

public:
    /// Create the network from a flatbuffers binary file on disk
    armnn::INetworkPtr CreateNetworkFromBinaryFile(const char* graphFile);

    /// Create the network from a flatbuffers binary
    armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t> & binaryContent);

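    // A minimal usage sketch (illustrative only; "model.tflite" and binaryContent
    // are assumed inputs, not part of this header):
    //
    //     TfLiteParserImpl parser;
    //     armnn::INetworkPtr network = parser.CreateNetworkFromBinaryFile("model.tflite");
    //     // or, for a model already read into memory:
    //     // armnn::INetworkPtr network = parser.CreateNetworkFromBinary(binaryContent);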

    /// Retrieve binding info (layer id and tensor info) for the network input identified by
    /// the given layer name and subgraph id
    BindingPointInfo GetNetworkInputBindingInfo(size_t subgraphId,
                                                const std::string& name) const;

    /// Retrieve binding info (layer id and tensor info) for the network output identified by
    /// the given layer name and subgraph id
    BindingPointInfo GetNetworkOutputBindingInfo(size_t subgraphId,
                                                 const std::string& name) const;

    /// Return the number of subgraphs in the parsed model
    size_t GetSubgraphCount() const;

    /// Return the input tensor names for a given subgraph
    std::vector<std::string> GetSubgraphInputTensorNames(size_t subgraphId) const;

    /// Return the output tensor names for a given subgraph
    std::vector<std::string> GetSubgraphOutputTensorNames(size_t subgraphId) const;

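    // Typical lookup sketch (illustrative; subgraph 0 and the first input name are
    // assumptions about a particular model):
    //
    //     std::vector<std::string> inputNames = parser.GetSubgraphInputTensorNames(0);
    //     BindingPointInfo inputBinding = parser.GetNetworkInputBindingInfo(0, inputNames[0]);
    //     // inputBinding holds the binding id and the armnn::TensorInfo for that input.
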
    TfLiteParserImpl(const armnn::Optional<ITfLiteParser::TfLiteParserOptions>& options = armnn::EmptyOptional());
    ~TfLiteParserImpl() = default;

public:
    // testable helpers
    armnn::INetworkPtr CreateNetworkFromBinaryAsDynamic(const std::vector<uint8_t>& binaryContent);

    armnn::INetworkPtr LoadModel(std::unique_ptr<tflite::ModelT> model);

    static ModelPtr LoadModelFromFile(const char* fileName);
    static ModelPtr LoadModelFromBinary(const uint8_t* binaryContent, size_t len);
    static TensorRawPtrVector GetInputs(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex);
    static TensorRawPtrVector GetOutputs(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex);
    static TensorIdRawPtrVector GetSubgraphInputs(const ModelPtr& model, size_t subgraphIndex);
    static TensorIdRawPtrVector GetSubgraphOutputs(const ModelPtr& model, size_t subgraphIndex);
    static std::vector<int32_t>& GetInputTensorIds(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex);
    static std::vector<int32_t>& GetOutputTensorIds(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex);
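
    // Sketch of how the static helpers compose (illustrative; "model.tflite" is an
    // assumed path):
    //
    //     ModelPtr model = TfLiteParserImpl::LoadModelFromFile("model.tflite");
    //     TensorIdRawPtrVector subgraphInputs = TfLiteParserImpl::GetSubgraphInputs(model, 0);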

    static BufferRawPtr GetBuffer(const ModelPtr& model, size_t bufferIndex);
    static armnn::TensorInfo OutputShapeOfSqueeze(std::vector<uint32_t> squeezeDims,
                                                  const armnn::TensorInfo& inputTensorInfo);
    static armnn::TensorInfo OutputShapeOfReshape(const armnn::TensorInfo& inputTensorInfo,
                                                  const std::vector<int32_t>& targetDimsIn);
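
    // Illustration of the reshape helper (a sketch of standard reshape semantics):
    // an input of shape { 2, 3, 4 } (24 elements) with targetDimsIn of { -1, 6 }
    // infers the -1 dimension, giving an output shape of { 4, 6 }.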

    /// Retrieve version in X.Y.Z form
    static const std::string GetVersion();

private:

    // No copying allowed until it is wanted and properly implemented
    TfLiteParserImpl(const TfLiteParserImpl &) = delete;
    TfLiteParserImpl & operator=(const TfLiteParserImpl &) = delete;

    /// Create the network from an already loaded flatbuffers model
    armnn::INetworkPtr CreateNetworkFromModel();

    // signature for the parser functions
    using OperatorParsingFunction = void(TfLiteParserImpl::*)(size_t subgraphIndex, size_t operatorIndex);
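
    // Operator dispatch uses pointers to member functions: the implementation fills
    // m_ParserFunctions so each builtin operator code maps to one of the Parse*
    // members below, then invokes the matching entry. A rough sketch of the pattern
    // (details live in TfLiteParser.cpp and may differ):
    //
    //     m_ParserFunctions[tflite::BuiltinOperator_ADD] = &TfLiteParserImpl::ParseAdd;
    //     ...
    //     OperatorParsingFunction parserFunction = m_ParserFunctions[operatorCode];
    //     (this->*parserFunction)(subgraphIndex, operatorIndex);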

    void ParseCustomOperator(size_t subgraphIndex, size_t operatorIndex);
    void ParseUnsupportedOperator(size_t subgraphIndex, size_t operatorIndex);

    void ParseAbs(size_t subgraphIndex, size_t operatorIndex);
    void ParseActivation(size_t subgraphIndex, size_t operatorIndex, armnn::ActivationFunction activationType);
    void ParseAdd(size_t subgraphIndex, size_t operatorIndex);
    void ParseArgMinMax(size_t subgraphIndex, size_t operatorIndex, armnn::ArgMinMaxFunction argMinMaxFunction);
    void ParseArgMin(size_t subgraphIndex, size_t operatorIndex);
    void ParseArgMax(size_t subgraphIndex, size_t operatorIndex);
    void ParseAveragePool2D(size_t subgraphIndex, size_t operatorIndex);
    void ParseBatchMatMul(size_t subgraphIndex, size_t operatorIndex);
    void ParseBatchToSpaceND(size_t subgraphIndex, size_t operatorIndex);
    void ParseCast(size_t subgraphIndex, size_t operatorIndex);
    void ParseCeil(size_t subgraphIndex, size_t operatorIndex);
    void ParseComparison(size_t subgraphIndex, size_t operatorIndex, armnn::ComparisonOperation comparisonOperation);
    void ParseConcatenation(size_t subgraphIndex, size_t operatorIndex);
    void ParseConv2D(size_t subgraphIndex, size_t operatorIndex);
    // Conv3D support was added in TF 2.5, so for backwards compatibility a preprocessor guard is needed.
    #if defined(ARMNN_POST_TFLITE_2_4)
    void ParseConv3D(size_t subgraphIndex, size_t operatorIndex);
    #endif
    void ParseDepthToSpace(size_t subgraphIndex, size_t operatorIndex);
    void ParseDepthwiseConv2D(size_t subgraphIndex, size_t operatorIndex);
    void ParseDequantize(size_t subgraphIndex, size_t operatorIndex);
    void ParseDetectionPostProcess(size_t subgraphIndex, size_t operatorIndex);
    void ParseDiv(size_t subgraphIndex, size_t operatorIndex);
    void ParseElementwiseUnary(size_t subgraphIndex, size_t operatorIndex, armnn::UnaryOperation unaryOperation);
    void ParseElu(size_t subgraphIndex, size_t operatorIndex);
    void ParseEqual(size_t subgraphIndex, size_t operatorIndex);
    void ParseExp(size_t subgraphIndex, size_t operatorIndex);
    void ParseExpandDims(size_t subgraphIndex, size_t operatorIndex);
    void ParseFloorDiv(size_t subgraphIndex, size_t operatorIndex);
    void ParseFullyConnected(size_t subgraphIndex, size_t operatorIndex);
    void ParseGather(size_t subgraphIndex, size_t operatorIndex);
    void ParseGatherNd(size_t subgraphIndex, size_t operatorIndex);
    void ParseGreater(size_t subgraphIndex, size_t operatorIndex);
    void ParseGreaterOrEqual(size_t subgraphIndex, size_t operatorIndex);
    void ParseHardSwish(size_t subgraphIndex, size_t operatorIndex);
    void ParseLeakyRelu(size_t subgraphIndex, size_t operatorIndex);
    void ParseLess(size_t subgraphIndex, size_t operatorIndex);
    void ParseLessOrEqual(size_t subgraphIndex, size_t operatorIndex);
    void ParseLog(size_t subgraphIndex, size_t operatorIndex);
    void ParseLocalResponseNormalization(size_t subgraphIndex, size_t operatorIndex);
    void ParseLogicalNot(size_t subgraphIndex, size_t operatorIndex);
    void ParseLogistic(size_t subgraphIndex, size_t operatorIndex);
    void ParseLogSoftmax(size_t subgraphIndex, size_t operatorIndex);
    void ParseL2Normalization(size_t subgraphIndex, size_t operatorIndex);
    void ParseMaxPool2D(size_t subgraphIndex, size_t operatorIndex);
    void ParseMaximum(size_t subgraphIndex, size_t operatorIndex);
    void ParseMean(size_t subgraphIndex, size_t operatorIndex);
    void ParseMinimum(size_t subgraphIndex, size_t operatorIndex);
    void ParseMirrorPad(size_t subgraphIndex, size_t operatorIndex);
    void ParseMul(size_t subgraphIndex, size_t operatorIndex);
    void ParseNeg(size_t subgraphIndex, size_t operatorIndex);
    void ParseNotEqual(size_t subgraphIndex, size_t operatorIndex);
    void ParsePack(size_t subgraphIndex, size_t operatorIndex);
    void ParsePad(size_t subgraphIndex, size_t operatorIndex);
    void ParsePool(size_t subgraphIndex, size_t operatorIndex, armnn::PoolingAlgorithm algorithm);
    void ParsePrelu(size_t subgraphIndex, size_t operatorIndex);
    void ParseQuantize(size_t subgraphIndex, size_t operatorIndex);
    void ParseReduce(size_t subgraphIndex, size_t operatorIndex, armnn::ReduceOperation reduceOperation);
    void ParseReduceMax(size_t subgraphIndex, size_t operatorIndex);
    void ParseReduceMin(size_t subgraphIndex, size_t operatorIndex);
    void ParseReduceProd(size_t subgraphIndex, size_t operatorIndex);
    void ParseRelu(size_t subgraphIndex, size_t operatorIndex);
    void ParseRelu6(size_t subgraphIndex, size_t operatorIndex);
    void ParseReshape(size_t subgraphIndex, size_t operatorIndex);
    void ParseResize(size_t subgraphIndex, size_t operatorIndex, armnn::ResizeMethod resizeMethod);
    void ParseResizeBilinear(size_t subgraphIndex, size_t operatorIndex);
    void ParseResizeNearestNeighbor(size_t subgraphIndex, size_t operatorIndex);
    void ParseRsqrt(size_t subgraphIndex, size_t operatorIndex);
    void ParseShape(size_t subgraphIndex, size_t operatorIndex);
    void ParseSin(size_t subgraphIndex, size_t operatorIndex);
    void ParseSlice(size_t subgraphIndex, size_t operatorIndex);
    void ParseSoftmax(size_t subgraphIndex, size_t operatorIndex);
    void ParseSqrt(size_t subgraphIndex, size_t operatorIndex);
    void ParseSpaceToBatchND(size_t subgraphIndex, size_t operatorIndex);
    void ParseSpaceToDepth(size_t subgraphIndex, size_t operatorIndex);
    void ParseSplit(size_t subgraphIndex, size_t operatorIndex);
    void ParseSplitV(size_t subgraphIndex, size_t operatorIndex);
    void ParseSqueeze(size_t subgraphIndex, size_t operatorIndex);
    void ParseStridedSlice(size_t subgraphIndex, size_t operatorIndex);
    void ParseSub(size_t subgraphIndex, size_t operatorIndex);
    void ParseSum(size_t subgraphIndex, size_t operatorIndex);
    void ParseTanH(size_t subgraphIndex, size_t operatorIndex);
    void ParseTranspose(size_t subgraphIndex, size_t operatorIndex);
    void ParseTransposeConv(size_t subgraphIndex, size_t operatorIndex);
    void ParseUnidirectionalSequenceLSTM(size_t subgraphIndex, size_t operatorIndex);
    void ParseUnpack(size_t subgraphIndex, size_t operatorIndex);

    void RegisterProducerOfTensor(size_t subgraphIndex, size_t tensorIndex, armnn::IOutputSlot* slot);
    void RegisterConsumerOfTensor(size_t subgraphIndex, size_t tensorIndex, armnn::IInputSlot* slot);
    void RegisterInputSlots(size_t subgraphIndex,
                            size_t operatorIndex,
                            armnn::IConnectableLayer* layer,
                            const std::vector<unsigned int>& tensorIndexes,
                            unsigned int startingSlotIndex = 0);
    void RegisterOutputSlots(size_t subgraphIndex,
                             size_t operatorIndex,
                             armnn::IConnectableLayer* layer,
                             const std::vector<unsigned int>& tensorIndexes);

    void SetupInputLayerTensorInfos(size_t subgraphIndex);
    void SetupConstantLayerTensorInfos(size_t subgraphIndex);

    void SetupInputLayers(size_t subgraphIndex);
    void SetupOutputLayers(size_t subgraphIndex);
    void SetupConstantLayers(size_t subgraphIndex);

    void ResetParser();

    void AddBroadcastReshapeLayer(size_t subgraphIndex,
                                  size_t operatorIndex,
                                  armnn::IConnectableLayer* layer);

    /// Attach a reshape layer to the one passed as a parameter
    armnn::IConnectableLayer* AddReshapeLayer(armnn::IConnectableLayer* layer,
                                              unsigned int outputSlot,
                                              std::string reshapeLayerName,
                                              armnn::TensorInfo outputShape);

    /// Attach an activation layer to the one passed as a parameter
    armnn::IConnectableLayer* AddFusedActivationLayer(armnn::IConnectableLayer* layer,
                                                      unsigned int outputSlot,
                                                      tflite::ActivationFunctionType activationType);

    /// Attach a floor layer to the one passed as a parameter
    armnn::IConnectableLayer* AddFusedFloorLayer(armnn::IConnectableLayer* layer, unsigned int outputSlot);

    // SupportedDataStorage's purpose is to hold data until it has been passed over to the network.
    // We don't care about the content, and we want a single datatype to simplify the code.
    struct SupportedDataStorage
    {
    public:
        // Convenience constructors
        SupportedDataStorage(std::unique_ptr<float[]>&&   data);
        SupportedDataStorage(std::unique_ptr<uint8_t[]>&& data);
        SupportedDataStorage(std::unique_ptr<int8_t[]>&&  data);
        SupportedDataStorage(std::unique_ptr<int32_t[]>&& data);

    private:
        // Pointers to the data buffers
        std::unique_ptr<float[]>   m_FloatData;
        std::unique_ptr<uint8_t[]> m_Uint8Data;
        std::unique_ptr<int8_t[]>  m_Int8Data;
        std::unique_ptr<int32_t[]> m_Int32Data;
    };

    bool ShouldConstantTensorBeCreated(unsigned int tensorIndex);

    bool IsConstTensor(TensorRawPtr tensorPtr);

    bool ShouldConstantTensorBeConverted(TfLiteParserImpl::TensorRawPtr tensorPtr,
                                         armnn::DataType inputDataType,
                                         armnn::DataType filterDataType);

    armnn::ConstTensor CreateConstTensorNonPermuted(TensorRawPtr tensorPtr,
                                                    armnn::TensorInfo& tensorInfo);

    std::pair<armnn::ConstTensor, SupportedDataStorage>
    CreateConstTensorPermuted(TensorRawPtr tensorPtr,
                              armnn::TensorInfo& tensorInfo,
                              armnn::Optional<armnn::PermutationVector&> permutationVector);

    std::pair<armnn::ConstTensor, std::unique_ptr<float[]>>
    CreateConstTensorNonPermuted(TensorRawPtr tensorPtr,
                                 armnn::TensorInfo& tensorInfo,
                                 armnn::DataType inputDataType);

    template<typename T>
    std::pair<armnn::ConstTensor, TfLiteParserImpl::SupportedDataStorage>
    CreateConstTensorAndStoreData(TfLiteParserImpl::BufferRawPtr bufferPtr,
                                  TfLiteParserImpl::TensorRawPtr tensorPtr,
                                  armnn::TensorInfo& tensorInfo,
                                  armnn::Optional<armnn::PermutationVector&> permutationVector);

    std::pair<armnn::ConstTensor*, std::unique_ptr<float[]>>
    CreateConstTensorPtr(TensorRawPtr tensorPtr,
                         armnn::TensorInfo& inputTensorInfo);

    armnn::TensorInfo InputTensorInfo(size_t subgraphIndex,
                                      size_t operatorIndex,
                                      int input);

    armnn::TensorInfo OutputTensorInfoFromInputs(size_t subgraphIndex,
                                       size_t operatorIndex,
                                       armnn::IConnectableLayer* layer,
                                       int output,
                                       std::vector<int> inputs);

    armnn::TensorInfo OutputTensorInfoFromShapes(size_t subgraphIndex,
                                       size_t operatorIndex,
                                       armnn::IConnectableLayer* layer,
                                       int output = 0,
                                       std::vector<armnn::TensorShape> inputShapes = {});

    /// Settings for configuring the TfLiteParser
    armnn::Optional<ITfLiteParser::TfLiteParserOptions> m_Options;

    /// The network we're building. Gets cleared after it is passed to the user
    armnn::INetworkPtr                    m_Network;
    ModelPtr                              m_Model;

    std::vector<OperatorParsingFunction>                     m_ParserFunctions;
    std::unordered_map<std::string, OperatorParsingFunction> m_CustomParserFunctions;

    /// A mapping of an output slot to each of the input slots it should be connected to
    /// The outputSlot is from the layer that creates this tensor as one of its outputs
    /// The inputSlots are from the layers that use this tensor as one of their inputs
    struct TensorSlots
    {
        armnn::IOutputSlot* outputSlot;
        std::vector<armnn::IInputSlot*> inputSlots;

        TensorSlots() : outputSlot(nullptr) { }
    };
    typedef std::vector<TensorSlots> TensorConnections;
    /// Connections for tensors in each subgraph
    /// The first index is the subgraph ID, the second index is the tensor ID
    std::vector<TensorConnections> m_SubgraphConnections;
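    // For example (illustrative): m_SubgraphConnections[subgraphIndex][tensorIndex].outputSlot
    // is the slot that produces the tensor, and .inputSlots are the consumers to be connected.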

    /// This is used in the case where the model does not specify the output shape.
    /// The shape can then be calculated from the options.
    std::vector<std::vector<unsigned int>> m_OverriddenOutputShapes;

    std::vector<unsigned int> m_ConstantsToDequantize;
    std::vector<unsigned int> m_ConstantsToBeCreated;
    std::map<size_t, armnn::TensorInfo> m_TensorInfos;
};

}