1 // 2 // Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 #pragma once 6 7 #include <armnn/Descriptors.hpp> 8 #include "armnn/INetwork.hpp" 9 #include "armnnTfLiteParser/ITfLiteParser.hpp" 10 #include "armnn/Types.hpp" 11 12 #include <schema_generated.h> 13 #include <functional> 14 #include <unordered_map> 15 #include <vector> 16 17 #include <tensorflow/lite/version.h> 18 19 #if TF_MAJOR_VERSION > 2 || (TF_MAJOR_VERSION == 2 && TF_MINOR_VERSION > 3) 20 #define ARMNN_POST_TFLITE_2_3 21 #endif 22 23 namespace armnnTfLiteParser 24 { 25 26 class TfLiteParserImpl 27 { 28 public: 29 // Shorthands for TfLite types 30 using ModelPtr = std::unique_ptr<tflite::ModelT>; 31 using SubgraphPtr = std::unique_ptr<tflite::SubGraphT>; 32 using OperatorPtr = std::unique_ptr<tflite::OperatorT>; 33 using OperatorCodePtr = std::unique_ptr<tflite::OperatorCodeT>; 34 using TensorPtr = std::unique_ptr<tflite::TensorT>; 35 using TensorRawPtr = const tflite::TensorT *; 36 using TensorRawPtrVector = std::vector<TensorRawPtr>; 37 using TensorIdRawPtr = std::pair<size_t, TensorRawPtr>; 38 using TensorIdRawPtrVector = std::vector<TensorIdRawPtr>; 39 using BufferPtr = std::unique_ptr<tflite::BufferT>; 40 using BufferRawPtr = const tflite::BufferT *; 41 42 public: 43 /// Create the network from a flatbuffers binary file on disk 44 armnn::INetworkPtr CreateNetworkFromBinaryFile(const char* graphFile); 45 46 /// Create the network from a flatbuffers binary 47 armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t> & binaryContent); 48 49 50 /// Retrieve binding info (layer id and tensor info) for the network input identified by 51 /// the given layer name and subgraph id 52 BindingPointInfo GetNetworkInputBindingInfo(size_t subgraphId, 53 const std::string& name) const; 54 55 /// Retrieve binding info (layer id and tensor info) for the network output identified by 56 /// the given layer name and subgraph id 57 BindingPointInfo GetNetworkOutputBindingInfo(size_t subgraphId, 58 const std::string& name) const; 59 60 /// Return the number of subgraphs in the parsed model 61 size_t GetSubgraphCount() const; 62 63 /// Return the input tensor names for a given subgraph 64 std::vector<std::string> GetSubgraphInputTensorNames(size_t subgraphId) const; 65 66 /// Return the output tensor names for a given subgraph 67 std::vector<std::string> GetSubgraphOutputTensorNames(size_t subgraphId) const; 68 69 TfLiteParserImpl(const armnn::Optional<ITfLiteParser::TfLiteParserOptions>& options = armnn::EmptyOptional()); 70 ~TfLiteParserImpl() = default; 71 72 public: 73 // testable helpers 74 armnn::INetworkPtr CreateNetworkFromBinaryAsDynamic(const std::vector<uint8_t>& binaryContent); 75 76 armnn::INetworkPtr LoadModel(std::unique_ptr<tflite::ModelT> model); 77 78 static ModelPtr LoadModelFromFile(const char* fileName); 79 static ModelPtr LoadModelFromBinary(const uint8_t* binaryContent, size_t len); 80 static TensorRawPtrVector GetInputs(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex); 81 static TensorRawPtrVector GetOutputs(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex); 82 static TensorIdRawPtrVector GetSubgraphInputs(const ModelPtr& model, size_t subgraphIndex); 83 static TensorIdRawPtrVector GetSubgraphOutputs(const ModelPtr& model, size_t subgraphIndex); 84 static std::vector<int32_t>& GetInputTensorIds(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex); 85 static std::vector<int32_t>& GetOutputTensorIds(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex); 86 87 static BufferRawPtr GetBuffer(const ModelPtr& model, size_t bufferIndex); 88 static armnn::TensorInfo OutputShapeOfSqueeze(std::vector<uint32_t> squeezeDims, 89 const armnn::TensorInfo& inputTensorInfo); 90 static armnn::TensorInfo OutputShapeOfReshape(const armnn::TensorInfo& inputTensorInfo, 91 const std::vector<int32_t>& targetDimsIn); 92 93 /// Retrieve version in X.Y.Z form 94 static const std::string GetVersion(); 95 96 private: 97 98 // No copying allowed until it is wanted and properly implemented 99 TfLiteParserImpl(const TfLiteParserImpl &) = delete; 100 TfLiteParserImpl & operator=(const TfLiteParserImpl &) = delete; 101 102 /// Create the network from an already loaded flatbuffers model 103 armnn::INetworkPtr CreateNetworkFromModel(); 104 105 // signature for the parser functions 106 using OperatorParsingFunction = void(TfLiteParserImpl::*)(size_t subgraphIndex, size_t operatorIndex); 107 108 void ParseCustomOperator(size_t subgraphIndex, size_t operatorIndex); 109 void ParseUnsupportedOperator(size_t subgraphIndex, size_t operatorIndex); 110 111 void ParseAbs(size_t subgraphIndex, size_t operatorIndex); 112 void ParseActivation(size_t subgraphIndex, size_t operatorIndex, armnn::ActivationFunction activationType); 113 void ParseAdd(size_t subgraphIndex, size_t operatorIndex); 114 void ParseArgMinMax(size_t subgraphIndex, size_t operatorIndex, armnn::ArgMinMaxFunction argMinMaxFunction); 115 void ParseArgMin(size_t subgraphIndex, size_t operatorIndex); 116 void ParseArgMax(size_t subgraphIndex, size_t operatorIndex); 117 void ParseAveragePool2D(size_t subgraphIndex, size_t operatorIndex); 118 void ParseBatchMatMul(size_t subgraphIndex, size_t operatorIndex); 119 void ParseBatchToSpaceND(size_t subgraphIndex, size_t operatorIndex); 120 void ParseCast(size_t subgraphIndex, size_t operatorIndex); 121 void ParseCeil(size_t subgraphIndex, size_t operatorIndex); 122 void ParseComparison(size_t subgraphIndex, size_t operatorIndex, armnn::ComparisonOperation comparisonOperation); 123 void ParseConcatenation(size_t subgraphIndex, size_t operatorIndex); 124 void ParseConv2D(size_t subgraphIndex, size_t operatorIndex); 125 // Conv3D support was added in TF 2.5, so for backwards compatibility a hash define is needed. 126 #if defined(ARMNN_POST_TFLITE_2_4) 127 void ParseConv3D(size_t subgraphIndex, size_t operatorIndex); 128 #endif 129 void ParseDepthToSpace(size_t subgraphIndex, size_t operatorIndex); 130 void ParseDepthwiseConv2D(size_t subgraphIndex, size_t operatorIndex); 131 void ParseDequantize(size_t subgraphIndex, size_t operatorIndex); 132 void ParseDetectionPostProcess(size_t subgraphIndex, size_t operatorIndex); 133 void ParseDiv(size_t subgraphIndex, size_t operatorIndex); 134 void ParseElementwiseUnary(size_t subgraphIndex, size_t operatorIndex, armnn::UnaryOperation unaryOperation); 135 void ParseElu(size_t subgraphIndex, size_t operatorIndex); 136 void ParseEqual(size_t subgraphIndex, size_t operatorIndex); 137 void ParseExp(size_t subgraphIndex, size_t operatorIndex); 138 void ParseExpandDims(size_t subgraphIndex, size_t operatorIndex); 139 void ParseFloorDiv(size_t subgraphIndex, size_t operatorIndex); 140 void ParseFullyConnected(size_t subgraphIndex, size_t operatorIndex); 141 void ParseGather(size_t subgraphIndex, size_t operatorIndex); 142 void ParseGatherNd(size_t subgraphIndex, size_t operatorIndex); 143 void ParseGreater(size_t subgraphIndex, size_t operatorIndex); 144 void ParseGreaterOrEqual(size_t subgraphIndex, size_t operatorIndex); 145 void ParseHardSwish(size_t subgraphIndex, size_t operatorIndex); 146 void ParseLeakyRelu(size_t subgraphIndex, size_t operatorIndex); 147 void ParseLess(size_t subgraphIndex, size_t operatorIndex); 148 void ParseLessOrEqual(size_t subgraphIndex, size_t operatorIndex); 149 void ParseLog(size_t subgraphIndex, size_t operatorIndex); 150 void ParseLocalResponseNormalization(size_t subgraphIndex, size_t operatorIndex); 151 void ParseLogicalNot(size_t subgraphIndex, size_t operatorIndex); 152 void ParseLogistic(size_t subgraphIndex, size_t operatorIndex); 153 void ParseLogSoftmax(size_t subgraphIndex, size_t operatorIndex); 154 void ParseL2Normalization(size_t subgraphIndex, size_t operatorIndex); 155 void ParseMaxPool2D(size_t subgraphIndex, size_t operatorIndex); 156 void ParseMaximum(size_t subgraphIndex, size_t operatorIndex); 157 void ParseMean(size_t subgraphIndex, size_t operatorIndex); 158 void ParseMinimum(size_t subgraphIndex, size_t operatorIndex); 159 void ParseMirrorPad(size_t subgraphIndex, size_t operatorIndex); 160 void ParseMul(size_t subgraphIndex, size_t operatorIndex); 161 void ParseNeg(size_t subgraphIndex, size_t operatorIndex); 162 void ParseNotEqual(size_t subgraphIndex, size_t operatorIndex); 163 void ParsePack(size_t subgraphIndex, size_t operatorIndex); 164 void ParsePad(size_t subgraphIndex, size_t operatorIndex); 165 void ParsePool(size_t subgraphIndex, size_t operatorIndex, armnn::PoolingAlgorithm algorithm); 166 void ParsePrelu(size_t subgraphIndex, size_t operatorIndex); 167 void ParseQuantize(size_t subgraphIndex, size_t operatorIndex); 168 void ParseReduce(size_t subgraphIndex, size_t operatorIndex, armnn::ReduceOperation reduceOperation); 169 void ParseReduceMax(size_t subgraphIndex, size_t operatorIndex); 170 void ParseReduceMin(size_t subgraphIndex, size_t operatorIndex); 171 void ParseReduceProd(size_t subgraphIndex, size_t operatorIndex); 172 void ParseRelu(size_t subgraphIndex, size_t operatorIndex); 173 void ParseRelu6(size_t subgraphIndex, size_t operatorIndex); 174 void ParseReshape(size_t subgraphIndex, size_t operatorIndex); 175 void ParseResize(size_t subgraphIndex, size_t operatorIndex, armnn::ResizeMethod resizeMethod); 176 void ParseResizeBilinear(size_t subgraphIndex, size_t operatorIndex); 177 void ParseResizeNearestNeighbor(size_t subgraphIndex, size_t operatorIndex); 178 void ParseRsqrt(size_t subgraphIndex, size_t operatorIndex); 179 void ParseShape(size_t subgraphIndex, size_t operatorIndex); 180 void ParseSin(size_t subgraphIndex, size_t operatorIndex); 181 void ParseSlice(size_t subgraphIndex, size_t operatorIndex); 182 void ParseSoftmax(size_t subgraphIndex, size_t operatorIndex); 183 void ParseSqrt(size_t subgraphIndex, size_t operatorIndex); 184 void ParseSpaceToBatchND(size_t subgraphIndex, size_t operatorIndex); 185 void ParseSpaceToDepth(size_t subgraphIndex, size_t operatorIndex); 186 void ParseSplit(size_t subgraphIndex, size_t operatorIndex); 187 void ParseSplitV(size_t subgraphIndex, size_t operatorIndex); 188 void ParseSqueeze(size_t subgraphIndex, size_t operatorIndex); 189 void ParseStridedSlice(size_t subgraphIndex, size_t operatorIndex); 190 void ParseSub(size_t subgraphIndex, size_t operatorIndex); 191 void ParseSum(size_t subgraphIndex, size_t operatorIndex); 192 void ParseTanH(size_t subgraphIndex, size_t operatorIndex); 193 void ParseTranspose(size_t subgraphIndex, size_t operatorIndex); 194 void ParseTransposeConv(size_t subgraphIndex, size_t operatorIndex); 195 void ParseUnidirectionalSequenceLSTM(size_t subgraphIndex, size_t operatorIndex); 196 void ParseUnpack(size_t subgraphIndex, size_t operatorIndex); 197 198 void RegisterProducerOfTensor(size_t subgraphIndex, size_t tensorIndex, armnn::IOutputSlot* slot); 199 void RegisterConsumerOfTensor(size_t subgraphIndex, size_t tensorIndex, armnn::IInputSlot* slot); 200 void RegisterInputSlots(size_t subgraphIndex, 201 size_t operatorIndex, 202 armnn::IConnectableLayer* layer, 203 const std::vector<unsigned int>& tensorIndexes, 204 unsigned int startingSlotIndex = 0); 205 void RegisterOutputSlots(size_t subgraphIndex, 206 size_t operatorIndex, 207 armnn::IConnectableLayer* layer, 208 const std::vector<unsigned int>& tensorIndexes); 209 210 void SetupInputLayerTensorInfos(size_t subgraphIndex); 211 void SetupConstantLayerTensorInfos(size_t subgraphIndex); 212 213 void SetupInputLayers(size_t subgraphIndex); 214 void SetupOutputLayers(size_t subgraphIndex); 215 void SetupConstantLayers(size_t subgraphIndex); 216 217 void ResetParser(); 218 219 void AddBroadcastReshapeLayer(size_t subgraphIndex, 220 size_t operatorIndex, 221 armnn::IConnectableLayer* layer); 222 223 /// Attach an reshape layer to the one passed as a parameter 224 armnn::IConnectableLayer* AddReshapeLayer(armnn::IConnectableLayer* layer, 225 unsigned int outputSlot, 226 std::string reshapeLayerName, 227 armnn::TensorInfo outputShape); 228 229 /// Attach an activation layer to the one passed as a parameter 230 armnn::IConnectableLayer* AddFusedActivationLayer(armnn::IConnectableLayer* layer, 231 unsigned int outputSlot, 232 tflite::ActivationFunctionType activationType); 233 234 /// Attach a floor layer to the one passed as a parameter 235 armnn::IConnectableLayer* AddFusedFloorLayer(armnn::IConnectableLayer* layer, unsigned int outputSlot); 236 237 // SupportedDataStorage's purpose is to hold data till we pass over to the network. 238 // We don't care about the content, and we want a single datatype to simplify the code. 239 struct SupportedDataStorage 240 { 241 public: 242 // Convenience constructors 243 SupportedDataStorage(std::unique_ptr<float[]>&& data); 244 SupportedDataStorage(std::unique_ptr<uint8_t[]>&& data); 245 SupportedDataStorage(std::unique_ptr<int8_t[]>&& data); 246 SupportedDataStorage(std::unique_ptr<int32_t[]>&& data); 247 248 private: 249 // Pointers to the data buffers 250 std::unique_ptr<float[]> m_FloatData; 251 std::unique_ptr<uint8_t[]> m_Uint8Data; 252 std::unique_ptr<int8_t[]> m_Int8Data; 253 std::unique_ptr<int32_t[]> m_Int32Data; 254 }; 255 256 bool ShouldConstantTensorBeCreated(unsigned int tensorIndex); 257 258 bool IsConstTensor(TensorRawPtr tensorPtr); 259 260 bool ShouldConstantTensorBeConverted(TfLiteParserImpl::TensorRawPtr tensorPtr, 261 armnn::DataType inputDataType, 262 armnn::DataType filterDataType); 263 264 armnn::ConstTensor CreateConstTensorNonPermuted(TensorRawPtr tensorPtr, 265 armnn::TensorInfo& tensorInfo); 266 267 std::pair<armnn::ConstTensor, SupportedDataStorage> 268 CreateConstTensorPermuted(TensorRawPtr tensorPtr, 269 armnn::TensorInfo& tensorInfo, 270 armnn::Optional<armnn::PermutationVector&> permutationVector); 271 272 std::pair<armnn::ConstTensor, std::unique_ptr<float[]>> 273 CreateConstTensorNonPermuted(TensorRawPtr tensorPtr, 274 armnn::TensorInfo& tensorInfo, 275 armnn::DataType inputDataType); 276 277 template<typename T> 278 std::pair<armnn::ConstTensor, TfLiteParserImpl::SupportedDataStorage> 279 CreateConstTensorAndStoreData(TfLiteParserImpl::BufferRawPtr bufferPtr, 280 TfLiteParserImpl::TensorRawPtr tensorPtr, 281 armnn::TensorInfo& tensorInfo, 282 armnn::Optional<armnn::PermutationVector&> permutationVector); 283 284 std::pair<armnn::ConstTensor*, std::unique_ptr<float[]>> 285 CreateConstTensorPtr(TensorRawPtr tensorPtr, 286 armnn::TensorInfo& inputTensorInfo); 287 288 armnn::TensorInfo InputTensorInfo(size_t subgraphIndex, 289 size_t operatorIndex, 290 int input); 291 292 armnn::TensorInfo OutputTensorInfoFromInputs(size_t subgraphIndex, 293 size_t operatorIndex, 294 armnn::IConnectableLayer* layer, 295 int output, 296 std::vector<int> inputs); 297 298 armnn::TensorInfo OutputTensorInfoFromShapes(size_t subgraphIndex, 299 size_t operatorIndex, 300 armnn::IConnectableLayer* layer, 301 int output = 0, 302 std::vector<armnn::TensorShape> inputShapes = {}); 303 304 /// Settings for configuring the TfLiteParser 305 armnn::Optional<ITfLiteParser::TfLiteParserOptions> m_Options; 306 307 /// The network we're building. Gets cleared after it is passed to the user 308 armnn::INetworkPtr m_Network; 309 ModelPtr m_Model; 310 311 std::vector<OperatorParsingFunction> m_ParserFunctions; 312 std::unordered_map<std::string, OperatorParsingFunction> m_CustomParserFunctions; 313 314 /// A mapping of an output slot to each of the input slots it should be connected to 315 /// The outputSlot is from the layer that creates this tensor as one of its ouputs 316 /// The inputSlots are from the layers that use this tensor as one of their inputs 317 struct TensorSlots 318 { 319 armnn::IOutputSlot* outputSlot; 320 std::vector<armnn::IInputSlot*> inputSlots; 321 TensorSlotsarmnnTfLiteParser::TfLiteParserImpl::TensorSlots322 TensorSlots() : outputSlot(nullptr) { } 323 }; 324 typedef std::vector<TensorSlots> TensorConnections; 325 /// Connections for tensors in each subgraph 326 /// The first index is the subgraph ID, the second index is the tensor ID 327 std::vector<TensorConnections> m_SubgraphConnections; 328 329 /// This is used in case that the model does not specify the output. 330 /// The shape can be calculated from the options. 331 std::vector<std::vector<unsigned int>> m_OverriddenOutputShapes; 332 333 std::vector<unsigned int> m_ConstantsToDequantize; 334 std::vector<unsigned int> m_ConstantsToBeCreated; 335 std::map<size_t, armnn::TensorInfo> m_TensorInfos; 336 }; 337 338 } 339