1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker #pragma once 6*89c4ff92SAndroid Build Coastguard Worker 7*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Descriptors.hpp> 8*89c4ff92SAndroid Build Coastguard Worker #include "armnn/INetwork.hpp" 9*89c4ff92SAndroid Build Coastguard Worker #include "armnnTfLiteParser/ITfLiteParser.hpp" 10*89c4ff92SAndroid Build Coastguard Worker #include "armnn/Types.hpp" 11*89c4ff92SAndroid Build Coastguard Worker 12*89c4ff92SAndroid Build Coastguard Worker #include <schema_generated.h> 13*89c4ff92SAndroid Build Coastguard Worker #include <functional> 14*89c4ff92SAndroid Build Coastguard Worker #include <unordered_map> 15*89c4ff92SAndroid Build Coastguard Worker #include <vector> 16*89c4ff92SAndroid Build Coastguard Worker 17*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/version.h> 18*89c4ff92SAndroid Build Coastguard Worker 19*89c4ff92SAndroid Build Coastguard Worker #if TF_MAJOR_VERSION > 2 || (TF_MAJOR_VERSION == 2 && TF_MINOR_VERSION > 3) 20*89c4ff92SAndroid Build Coastguard Worker #define ARMNN_POST_TFLITE_2_3 21*89c4ff92SAndroid Build Coastguard Worker #endif 22*89c4ff92SAndroid Build Coastguard Worker 23*89c4ff92SAndroid Build Coastguard Worker namespace armnnTfLiteParser 24*89c4ff92SAndroid Build Coastguard Worker { 25*89c4ff92SAndroid Build Coastguard Worker 26*89c4ff92SAndroid Build Coastguard Worker class TfLiteParserImpl 27*89c4ff92SAndroid Build Coastguard Worker { 28*89c4ff92SAndroid Build Coastguard Worker public: 29*89c4ff92SAndroid Build Coastguard Worker // Shorthands for TfLite types 30*89c4ff92SAndroid Build Coastguard Worker using ModelPtr = std::unique_ptr<tflite::ModelT>; 31*89c4ff92SAndroid Build Coastguard Worker using SubgraphPtr = std::unique_ptr<tflite::SubGraphT>; 32*89c4ff92SAndroid Build Coastguard Worker using OperatorPtr = std::unique_ptr<tflite::OperatorT>; 33*89c4ff92SAndroid Build Coastguard Worker using OperatorCodePtr = std::unique_ptr<tflite::OperatorCodeT>; 34*89c4ff92SAndroid Build Coastguard Worker using TensorPtr = std::unique_ptr<tflite::TensorT>; 35*89c4ff92SAndroid Build Coastguard Worker using TensorRawPtr = const tflite::TensorT *; 36*89c4ff92SAndroid Build Coastguard Worker using TensorRawPtrVector = std::vector<TensorRawPtr>; 37*89c4ff92SAndroid Build Coastguard Worker using TensorIdRawPtr = std::pair<size_t, TensorRawPtr>; 38*89c4ff92SAndroid Build Coastguard Worker using TensorIdRawPtrVector = std::vector<TensorIdRawPtr>; 39*89c4ff92SAndroid Build Coastguard Worker using BufferPtr = std::unique_ptr<tflite::BufferT>; 40*89c4ff92SAndroid Build Coastguard Worker using BufferRawPtr = const tflite::BufferT *; 41*89c4ff92SAndroid Build Coastguard Worker 42*89c4ff92SAndroid Build Coastguard Worker public: 43*89c4ff92SAndroid Build Coastguard Worker /// Create the network from a flatbuffers binary file on disk 44*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr CreateNetworkFromBinaryFile(const char* graphFile); 45*89c4ff92SAndroid Build Coastguard Worker 46*89c4ff92SAndroid Build Coastguard Worker /// Create the network from a flatbuffers binary 47*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t> & binaryContent); 48*89c4ff92SAndroid Build Coastguard Worker 49*89c4ff92SAndroid Build Coastguard Worker 50*89c4ff92SAndroid Build Coastguard Worker /// Retrieve binding info (layer id and tensor info) for the network input identified by 51*89c4ff92SAndroid Build Coastguard Worker /// the given layer name and subgraph id 52*89c4ff92SAndroid Build Coastguard Worker BindingPointInfo GetNetworkInputBindingInfo(size_t subgraphId, 53*89c4ff92SAndroid Build Coastguard Worker const std::string& name) const; 54*89c4ff92SAndroid Build Coastguard Worker 55*89c4ff92SAndroid Build Coastguard Worker /// Retrieve binding info (layer id and tensor info) for the network output identified by 56*89c4ff92SAndroid Build Coastguard Worker /// the given layer name and subgraph id 57*89c4ff92SAndroid Build Coastguard Worker BindingPointInfo GetNetworkOutputBindingInfo(size_t subgraphId, 58*89c4ff92SAndroid Build Coastguard Worker const std::string& name) const; 59*89c4ff92SAndroid Build Coastguard Worker 60*89c4ff92SAndroid Build Coastguard Worker /// Return the number of subgraphs in the parsed model 61*89c4ff92SAndroid Build Coastguard Worker size_t GetSubgraphCount() const; 62*89c4ff92SAndroid Build Coastguard Worker 63*89c4ff92SAndroid Build Coastguard Worker /// Return the input tensor names for a given subgraph 64*89c4ff92SAndroid Build Coastguard Worker std::vector<std::string> GetSubgraphInputTensorNames(size_t subgraphId) const; 65*89c4ff92SAndroid Build Coastguard Worker 66*89c4ff92SAndroid Build Coastguard Worker /// Return the output tensor names for a given subgraph 67*89c4ff92SAndroid Build Coastguard Worker std::vector<std::string> GetSubgraphOutputTensorNames(size_t subgraphId) const; 68*89c4ff92SAndroid Build Coastguard Worker 69*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl(const armnn::Optional<ITfLiteParser::TfLiteParserOptions>& options = armnn::EmptyOptional()); 70*89c4ff92SAndroid Build Coastguard Worker ~TfLiteParserImpl() = default; 71*89c4ff92SAndroid Build Coastguard Worker 72*89c4ff92SAndroid Build Coastguard Worker public: 73*89c4ff92SAndroid Build Coastguard Worker // testable helpers 74*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr CreateNetworkFromBinaryAsDynamic(const std::vector<uint8_t>& binaryContent); 75*89c4ff92SAndroid Build Coastguard Worker 76*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr LoadModel(std::unique_ptr<tflite::ModelT> model); 77*89c4ff92SAndroid Build Coastguard Worker 78*89c4ff92SAndroid Build Coastguard Worker static ModelPtr LoadModelFromFile(const char* fileName); 79*89c4ff92SAndroid Build Coastguard Worker static ModelPtr LoadModelFromBinary(const uint8_t* binaryContent, size_t len); 80*89c4ff92SAndroid Build Coastguard Worker static TensorRawPtrVector GetInputs(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex); 81*89c4ff92SAndroid Build Coastguard Worker static TensorRawPtrVector GetOutputs(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex); 82*89c4ff92SAndroid Build Coastguard Worker static TensorIdRawPtrVector GetSubgraphInputs(const ModelPtr& model, size_t subgraphIndex); 83*89c4ff92SAndroid Build Coastguard Worker static TensorIdRawPtrVector GetSubgraphOutputs(const ModelPtr& model, size_t subgraphIndex); 84*89c4ff92SAndroid Build Coastguard Worker static std::vector<int32_t>& GetInputTensorIds(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex); 85*89c4ff92SAndroid Build Coastguard Worker static std::vector<int32_t>& GetOutputTensorIds(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex); 86*89c4ff92SAndroid Build Coastguard Worker 87*89c4ff92SAndroid Build Coastguard Worker static BufferRawPtr GetBuffer(const ModelPtr& model, size_t bufferIndex); 88*89c4ff92SAndroid Build Coastguard Worker static armnn::TensorInfo OutputShapeOfSqueeze(std::vector<uint32_t> squeezeDims, 89*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& inputTensorInfo); 90*89c4ff92SAndroid Build Coastguard Worker static armnn::TensorInfo OutputShapeOfReshape(const armnn::TensorInfo& inputTensorInfo, 91*89c4ff92SAndroid Build Coastguard Worker const std::vector<int32_t>& targetDimsIn); 92*89c4ff92SAndroid Build Coastguard Worker 93*89c4ff92SAndroid Build Coastguard Worker /// Retrieve version in X.Y.Z form 94*89c4ff92SAndroid Build Coastguard Worker static const std::string GetVersion(); 95*89c4ff92SAndroid Build Coastguard Worker 96*89c4ff92SAndroid Build Coastguard Worker private: 97*89c4ff92SAndroid Build Coastguard Worker 98*89c4ff92SAndroid Build Coastguard Worker // No copying allowed until it is wanted and properly implemented 99*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl(const TfLiteParserImpl &) = delete; 100*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl & operator=(const TfLiteParserImpl &) = delete; 101*89c4ff92SAndroid Build Coastguard Worker 102*89c4ff92SAndroid Build Coastguard Worker /// Create the network from an already loaded flatbuffers model 103*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr CreateNetworkFromModel(); 104*89c4ff92SAndroid Build Coastguard Worker 105*89c4ff92SAndroid Build Coastguard Worker // signature for the parser functions 106*89c4ff92SAndroid Build Coastguard Worker using OperatorParsingFunction = void(TfLiteParserImpl::*)(size_t subgraphIndex, size_t operatorIndex); 107*89c4ff92SAndroid Build Coastguard Worker 108*89c4ff92SAndroid Build Coastguard Worker void ParseCustomOperator(size_t subgraphIndex, size_t operatorIndex); 109*89c4ff92SAndroid Build Coastguard Worker void ParseUnsupportedOperator(size_t subgraphIndex, size_t operatorIndex); 110*89c4ff92SAndroid Build Coastguard Worker 111*89c4ff92SAndroid Build Coastguard Worker void ParseAbs(size_t subgraphIndex, size_t operatorIndex); 112*89c4ff92SAndroid Build Coastguard Worker void ParseActivation(size_t subgraphIndex, size_t operatorIndex, armnn::ActivationFunction activationType); 113*89c4ff92SAndroid Build Coastguard Worker void ParseAdd(size_t subgraphIndex, size_t operatorIndex); 114*89c4ff92SAndroid Build Coastguard Worker void ParseArgMinMax(size_t subgraphIndex, size_t operatorIndex, armnn::ArgMinMaxFunction argMinMaxFunction); 115*89c4ff92SAndroid Build Coastguard Worker void ParseArgMin(size_t subgraphIndex, size_t operatorIndex); 116*89c4ff92SAndroid Build Coastguard Worker void ParseArgMax(size_t subgraphIndex, size_t operatorIndex); 117*89c4ff92SAndroid Build Coastguard Worker void ParseAveragePool2D(size_t subgraphIndex, size_t operatorIndex); 118*89c4ff92SAndroid Build Coastguard Worker void ParseBatchMatMul(size_t subgraphIndex, size_t operatorIndex); 119*89c4ff92SAndroid Build Coastguard Worker void ParseBatchToSpaceND(size_t subgraphIndex, size_t operatorIndex); 120*89c4ff92SAndroid Build Coastguard Worker void ParseCast(size_t subgraphIndex, size_t operatorIndex); 121*89c4ff92SAndroid Build Coastguard Worker void ParseCeil(size_t subgraphIndex, size_t operatorIndex); 122*89c4ff92SAndroid Build Coastguard Worker void ParseComparison(size_t subgraphIndex, size_t operatorIndex, armnn::ComparisonOperation comparisonOperation); 123*89c4ff92SAndroid Build Coastguard Worker void ParseConcatenation(size_t subgraphIndex, size_t operatorIndex); 124*89c4ff92SAndroid Build Coastguard Worker void ParseConv2D(size_t subgraphIndex, size_t operatorIndex); 125*89c4ff92SAndroid Build Coastguard Worker // Conv3D support was added in TF 2.5, so for backwards compatibility a hash define is needed. 126*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_POST_TFLITE_2_4) 127*89c4ff92SAndroid Build Coastguard Worker void ParseConv3D(size_t subgraphIndex, size_t operatorIndex); 128*89c4ff92SAndroid Build Coastguard Worker #endif 129*89c4ff92SAndroid Build Coastguard Worker void ParseDepthToSpace(size_t subgraphIndex, size_t operatorIndex); 130*89c4ff92SAndroid Build Coastguard Worker void ParseDepthwiseConv2D(size_t subgraphIndex, size_t operatorIndex); 131*89c4ff92SAndroid Build Coastguard Worker void ParseDequantize(size_t subgraphIndex, size_t operatorIndex); 132*89c4ff92SAndroid Build Coastguard Worker void ParseDetectionPostProcess(size_t subgraphIndex, size_t operatorIndex); 133*89c4ff92SAndroid Build Coastguard Worker void ParseDiv(size_t subgraphIndex, size_t operatorIndex); 134*89c4ff92SAndroid Build Coastguard Worker void ParseElementwiseUnary(size_t subgraphIndex, size_t operatorIndex, armnn::UnaryOperation unaryOperation); 135*89c4ff92SAndroid Build Coastguard Worker void ParseElu(size_t subgraphIndex, size_t operatorIndex); 136*89c4ff92SAndroid Build Coastguard Worker void ParseEqual(size_t subgraphIndex, size_t operatorIndex); 137*89c4ff92SAndroid Build Coastguard Worker void ParseExp(size_t subgraphIndex, size_t operatorIndex); 138*89c4ff92SAndroid Build Coastguard Worker void ParseExpandDims(size_t subgraphIndex, size_t operatorIndex); 139*89c4ff92SAndroid Build Coastguard Worker void ParseFloorDiv(size_t subgraphIndex, size_t operatorIndex); 140*89c4ff92SAndroid Build Coastguard Worker void ParseFullyConnected(size_t subgraphIndex, size_t operatorIndex); 141*89c4ff92SAndroid Build Coastguard Worker void ParseGather(size_t subgraphIndex, size_t operatorIndex); 142*89c4ff92SAndroid Build Coastguard Worker void ParseGatherNd(size_t subgraphIndex, size_t operatorIndex); 143*89c4ff92SAndroid Build Coastguard Worker void ParseGreater(size_t subgraphIndex, size_t operatorIndex); 144*89c4ff92SAndroid Build Coastguard Worker void ParseGreaterOrEqual(size_t subgraphIndex, size_t operatorIndex); 145*89c4ff92SAndroid Build Coastguard Worker void ParseHardSwish(size_t subgraphIndex, size_t operatorIndex); 146*89c4ff92SAndroid Build Coastguard Worker void ParseLeakyRelu(size_t subgraphIndex, size_t operatorIndex); 147*89c4ff92SAndroid Build Coastguard Worker void ParseLess(size_t subgraphIndex, size_t operatorIndex); 148*89c4ff92SAndroid Build Coastguard Worker void ParseLessOrEqual(size_t subgraphIndex, size_t operatorIndex); 149*89c4ff92SAndroid Build Coastguard Worker void ParseLog(size_t subgraphIndex, size_t operatorIndex); 150*89c4ff92SAndroid Build Coastguard Worker void ParseLocalResponseNormalization(size_t subgraphIndex, size_t operatorIndex); 151*89c4ff92SAndroid Build Coastguard Worker void ParseLogicalNot(size_t subgraphIndex, size_t operatorIndex); 152*89c4ff92SAndroid Build Coastguard Worker void ParseLogistic(size_t subgraphIndex, size_t operatorIndex); 153*89c4ff92SAndroid Build Coastguard Worker void ParseLogSoftmax(size_t subgraphIndex, size_t operatorIndex); 154*89c4ff92SAndroid Build Coastguard Worker void ParseL2Normalization(size_t subgraphIndex, size_t operatorIndex); 155*89c4ff92SAndroid Build Coastguard Worker void ParseMaxPool2D(size_t subgraphIndex, size_t operatorIndex); 156*89c4ff92SAndroid Build Coastguard Worker void ParseMaximum(size_t subgraphIndex, size_t operatorIndex); 157*89c4ff92SAndroid Build Coastguard Worker void ParseMean(size_t subgraphIndex, size_t operatorIndex); 158*89c4ff92SAndroid Build Coastguard Worker void ParseMinimum(size_t subgraphIndex, size_t operatorIndex); 159*89c4ff92SAndroid Build Coastguard Worker void ParseMirrorPad(size_t subgraphIndex, size_t operatorIndex); 160*89c4ff92SAndroid Build Coastguard Worker void ParseMul(size_t subgraphIndex, size_t operatorIndex); 161*89c4ff92SAndroid Build Coastguard Worker void ParseNeg(size_t subgraphIndex, size_t operatorIndex); 162*89c4ff92SAndroid Build Coastguard Worker void ParseNotEqual(size_t subgraphIndex, size_t operatorIndex); 163*89c4ff92SAndroid Build Coastguard Worker void ParsePack(size_t subgraphIndex, size_t operatorIndex); 164*89c4ff92SAndroid Build Coastguard Worker void ParsePad(size_t subgraphIndex, size_t operatorIndex); 165*89c4ff92SAndroid Build Coastguard Worker void ParsePool(size_t subgraphIndex, size_t operatorIndex, armnn::PoolingAlgorithm algorithm); 166*89c4ff92SAndroid Build Coastguard Worker void ParsePrelu(size_t subgraphIndex, size_t operatorIndex); 167*89c4ff92SAndroid Build Coastguard Worker void ParseQuantize(size_t subgraphIndex, size_t operatorIndex); 168*89c4ff92SAndroid Build Coastguard Worker void ParseReduce(size_t subgraphIndex, size_t operatorIndex, armnn::ReduceOperation reduceOperation); 169*89c4ff92SAndroid Build Coastguard Worker void ParseReduceMax(size_t subgraphIndex, size_t operatorIndex); 170*89c4ff92SAndroid Build Coastguard Worker void ParseReduceMin(size_t subgraphIndex, size_t operatorIndex); 171*89c4ff92SAndroid Build Coastguard Worker void ParseReduceProd(size_t subgraphIndex, size_t operatorIndex); 172*89c4ff92SAndroid Build Coastguard Worker void ParseRelu(size_t subgraphIndex, size_t operatorIndex); 173*89c4ff92SAndroid Build Coastguard Worker void ParseRelu6(size_t subgraphIndex, size_t operatorIndex); 174*89c4ff92SAndroid Build Coastguard Worker void ParseReshape(size_t subgraphIndex, size_t operatorIndex); 175*89c4ff92SAndroid Build Coastguard Worker void ParseResize(size_t subgraphIndex, size_t operatorIndex, armnn::ResizeMethod resizeMethod); 176*89c4ff92SAndroid Build Coastguard Worker void ParseResizeBilinear(size_t subgraphIndex, size_t operatorIndex); 177*89c4ff92SAndroid Build Coastguard Worker void ParseResizeNearestNeighbor(size_t subgraphIndex, size_t operatorIndex); 178*89c4ff92SAndroid Build Coastguard Worker void ParseRsqrt(size_t subgraphIndex, size_t operatorIndex); 179*89c4ff92SAndroid Build Coastguard Worker void ParseShape(size_t subgraphIndex, size_t operatorIndex); 180*89c4ff92SAndroid Build Coastguard Worker void ParseSin(size_t subgraphIndex, size_t operatorIndex); 181*89c4ff92SAndroid Build Coastguard Worker void ParseSlice(size_t subgraphIndex, size_t operatorIndex); 182*89c4ff92SAndroid Build Coastguard Worker void ParseSoftmax(size_t subgraphIndex, size_t operatorIndex); 183*89c4ff92SAndroid Build Coastguard Worker void ParseSqrt(size_t subgraphIndex, size_t operatorIndex); 184*89c4ff92SAndroid Build Coastguard Worker void ParseSpaceToBatchND(size_t subgraphIndex, size_t operatorIndex); 185*89c4ff92SAndroid Build Coastguard Worker void ParseSpaceToDepth(size_t subgraphIndex, size_t operatorIndex); 186*89c4ff92SAndroid Build Coastguard Worker void ParseSplit(size_t subgraphIndex, size_t operatorIndex); 187*89c4ff92SAndroid Build Coastguard Worker void ParseSplitV(size_t subgraphIndex, size_t operatorIndex); 188*89c4ff92SAndroid Build Coastguard Worker void ParseSqueeze(size_t subgraphIndex, size_t operatorIndex); 189*89c4ff92SAndroid Build Coastguard Worker void ParseStridedSlice(size_t subgraphIndex, size_t operatorIndex); 190*89c4ff92SAndroid Build Coastguard Worker void ParseSub(size_t subgraphIndex, size_t operatorIndex); 191*89c4ff92SAndroid Build Coastguard Worker void ParseSum(size_t subgraphIndex, size_t operatorIndex); 192*89c4ff92SAndroid Build Coastguard Worker void ParseTanH(size_t subgraphIndex, size_t operatorIndex); 193*89c4ff92SAndroid Build Coastguard Worker void ParseTranspose(size_t subgraphIndex, size_t operatorIndex); 194*89c4ff92SAndroid Build Coastguard Worker void ParseTransposeConv(size_t subgraphIndex, size_t operatorIndex); 195*89c4ff92SAndroid Build Coastguard Worker void ParseUnidirectionalSequenceLSTM(size_t subgraphIndex, size_t operatorIndex); 196*89c4ff92SAndroid Build Coastguard Worker void ParseUnpack(size_t subgraphIndex, size_t operatorIndex); 197*89c4ff92SAndroid Build Coastguard Worker 198*89c4ff92SAndroid Build Coastguard Worker void RegisterProducerOfTensor(size_t subgraphIndex, size_t tensorIndex, armnn::IOutputSlot* slot); 199*89c4ff92SAndroid Build Coastguard Worker void RegisterConsumerOfTensor(size_t subgraphIndex, size_t tensorIndex, armnn::IInputSlot* slot); 200*89c4ff92SAndroid Build Coastguard Worker void RegisterInputSlots(size_t subgraphIndex, 201*89c4ff92SAndroid Build Coastguard Worker size_t operatorIndex, 202*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* layer, 203*89c4ff92SAndroid Build Coastguard Worker const std::vector<unsigned int>& tensorIndexes, 204*89c4ff92SAndroid Build Coastguard Worker unsigned int startingSlotIndex = 0); 205*89c4ff92SAndroid Build Coastguard Worker void RegisterOutputSlots(size_t subgraphIndex, 206*89c4ff92SAndroid Build Coastguard Worker size_t operatorIndex, 207*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* layer, 208*89c4ff92SAndroid Build Coastguard Worker const std::vector<unsigned int>& tensorIndexes); 209*89c4ff92SAndroid Build Coastguard Worker 210*89c4ff92SAndroid Build Coastguard Worker void SetupInputLayerTensorInfos(size_t subgraphIndex); 211*89c4ff92SAndroid Build Coastguard Worker void SetupConstantLayerTensorInfos(size_t subgraphIndex); 212*89c4ff92SAndroid Build Coastguard Worker 213*89c4ff92SAndroid Build Coastguard Worker void SetupInputLayers(size_t subgraphIndex); 214*89c4ff92SAndroid Build Coastguard Worker void SetupOutputLayers(size_t subgraphIndex); 215*89c4ff92SAndroid Build Coastguard Worker void SetupConstantLayers(size_t subgraphIndex); 216*89c4ff92SAndroid Build Coastguard Worker 217*89c4ff92SAndroid Build Coastguard Worker void ResetParser(); 218*89c4ff92SAndroid Build Coastguard Worker 219*89c4ff92SAndroid Build Coastguard Worker void AddBroadcastReshapeLayer(size_t subgraphIndex, 220*89c4ff92SAndroid Build Coastguard Worker size_t operatorIndex, 221*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* layer); 222*89c4ff92SAndroid Build Coastguard Worker 223*89c4ff92SAndroid Build Coastguard Worker /// Attach an reshape layer to the one passed as a parameter 224*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* AddReshapeLayer(armnn::IConnectableLayer* layer, 225*89c4ff92SAndroid Build Coastguard Worker unsigned int outputSlot, 226*89c4ff92SAndroid Build Coastguard Worker std::string reshapeLayerName, 227*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputShape); 228*89c4ff92SAndroid Build Coastguard Worker 229*89c4ff92SAndroid Build Coastguard Worker /// Attach an activation layer to the one passed as a parameter 230*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* AddFusedActivationLayer(armnn::IConnectableLayer* layer, 231*89c4ff92SAndroid Build Coastguard Worker unsigned int outputSlot, 232*89c4ff92SAndroid Build Coastguard Worker tflite::ActivationFunctionType activationType); 233*89c4ff92SAndroid Build Coastguard Worker 234*89c4ff92SAndroid Build Coastguard Worker /// Attach a floor layer to the one passed as a parameter 235*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* AddFusedFloorLayer(armnn::IConnectableLayer* layer, unsigned int outputSlot); 236*89c4ff92SAndroid Build Coastguard Worker 237*89c4ff92SAndroid Build Coastguard Worker // SupportedDataStorage's purpose is to hold data till we pass over to the network. 238*89c4ff92SAndroid Build Coastguard Worker // We don't care about the content, and we want a single datatype to simplify the code. 239*89c4ff92SAndroid Build Coastguard Worker struct SupportedDataStorage 240*89c4ff92SAndroid Build Coastguard Worker { 241*89c4ff92SAndroid Build Coastguard Worker public: 242*89c4ff92SAndroid Build Coastguard Worker // Convenience constructors 243*89c4ff92SAndroid Build Coastguard Worker SupportedDataStorage(std::unique_ptr<float[]>&& data); 244*89c4ff92SAndroid Build Coastguard Worker SupportedDataStorage(std::unique_ptr<uint8_t[]>&& data); 245*89c4ff92SAndroid Build Coastguard Worker SupportedDataStorage(std::unique_ptr<int8_t[]>&& data); 246*89c4ff92SAndroid Build Coastguard Worker SupportedDataStorage(std::unique_ptr<int32_t[]>&& data); 247*89c4ff92SAndroid Build Coastguard Worker 248*89c4ff92SAndroid Build Coastguard Worker private: 249*89c4ff92SAndroid Build Coastguard Worker // Pointers to the data buffers 250*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<float[]> m_FloatData; 251*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<uint8_t[]> m_Uint8Data; 252*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<int8_t[]> m_Int8Data; 253*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<int32_t[]> m_Int32Data; 254*89c4ff92SAndroid Build Coastguard Worker }; 255*89c4ff92SAndroid Build Coastguard Worker 256*89c4ff92SAndroid Build Coastguard Worker bool ShouldConstantTensorBeCreated(unsigned int tensorIndex); 257*89c4ff92SAndroid Build Coastguard Worker 258*89c4ff92SAndroid Build Coastguard Worker bool IsConstTensor(TensorRawPtr tensorPtr); 259*89c4ff92SAndroid Build Coastguard Worker 260*89c4ff92SAndroid Build Coastguard Worker bool ShouldConstantTensorBeConverted(TfLiteParserImpl::TensorRawPtr tensorPtr, 261*89c4ff92SAndroid Build Coastguard Worker armnn::DataType inputDataType, 262*89c4ff92SAndroid Build Coastguard Worker armnn::DataType filterDataType); 263*89c4ff92SAndroid Build Coastguard Worker 264*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor CreateConstTensorNonPermuted(TensorRawPtr tensorPtr, 265*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo& tensorInfo); 266*89c4ff92SAndroid Build Coastguard Worker 267*89c4ff92SAndroid Build Coastguard Worker std::pair<armnn::ConstTensor, SupportedDataStorage> 268*89c4ff92SAndroid Build Coastguard Worker CreateConstTensorPermuted(TensorRawPtr tensorPtr, 269*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo& tensorInfo, 270*89c4ff92SAndroid Build Coastguard Worker armnn::Optional<armnn::PermutationVector&> permutationVector); 271*89c4ff92SAndroid Build Coastguard Worker 272*89c4ff92SAndroid Build Coastguard Worker std::pair<armnn::ConstTensor, std::unique_ptr<float[]>> 273*89c4ff92SAndroid Build Coastguard Worker CreateConstTensorNonPermuted(TensorRawPtr tensorPtr, 274*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo& tensorInfo, 275*89c4ff92SAndroid Build Coastguard Worker armnn::DataType inputDataType); 276*89c4ff92SAndroid Build Coastguard Worker 277*89c4ff92SAndroid Build Coastguard Worker template<typename T> 278*89c4ff92SAndroid Build Coastguard Worker std::pair<armnn::ConstTensor, TfLiteParserImpl::SupportedDataStorage> 279*89c4ff92SAndroid Build Coastguard Worker CreateConstTensorAndStoreData(TfLiteParserImpl::BufferRawPtr bufferPtr, 280*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl::TensorRawPtr tensorPtr, 281*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo& tensorInfo, 282*89c4ff92SAndroid Build Coastguard Worker armnn::Optional<armnn::PermutationVector&> permutationVector); 283*89c4ff92SAndroid Build Coastguard Worker 284*89c4ff92SAndroid Build Coastguard Worker std::pair<armnn::ConstTensor*, std::unique_ptr<float[]>> 285*89c4ff92SAndroid Build Coastguard Worker CreateConstTensorPtr(TensorRawPtr tensorPtr, 286*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo& inputTensorInfo); 287*89c4ff92SAndroid Build Coastguard Worker 288*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo InputTensorInfo(size_t subgraphIndex, 289*89c4ff92SAndroid Build Coastguard Worker size_t operatorIndex, 290*89c4ff92SAndroid Build Coastguard Worker int input); 291*89c4ff92SAndroid Build Coastguard Worker 292*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo OutputTensorInfoFromInputs(size_t subgraphIndex, 293*89c4ff92SAndroid Build Coastguard Worker size_t operatorIndex, 294*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* layer, 295*89c4ff92SAndroid Build Coastguard Worker int output, 296*89c4ff92SAndroid Build Coastguard Worker std::vector<int> inputs); 297*89c4ff92SAndroid Build Coastguard Worker 298*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo OutputTensorInfoFromShapes(size_t subgraphIndex, 299*89c4ff92SAndroid Build Coastguard Worker size_t operatorIndex, 300*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* layer, 301*89c4ff92SAndroid Build Coastguard Worker int output = 0, 302*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::TensorShape> inputShapes = {}); 303*89c4ff92SAndroid Build Coastguard Worker 304*89c4ff92SAndroid Build Coastguard Worker /// Settings for configuring the TfLiteParser 305*89c4ff92SAndroid Build Coastguard Worker armnn::Optional<ITfLiteParser::TfLiteParserOptions> m_Options; 306*89c4ff92SAndroid Build Coastguard Worker 307*89c4ff92SAndroid Build Coastguard Worker /// The network we're building. Gets cleared after it is passed to the user 308*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr m_Network; 309*89c4ff92SAndroid Build Coastguard Worker ModelPtr m_Model; 310*89c4ff92SAndroid Build Coastguard Worker 311*89c4ff92SAndroid Build Coastguard Worker std::vector<OperatorParsingFunction> m_ParserFunctions; 312*89c4ff92SAndroid Build Coastguard Worker std::unordered_map<std::string, OperatorParsingFunction> m_CustomParserFunctions; 313*89c4ff92SAndroid Build Coastguard Worker 314*89c4ff92SAndroid Build Coastguard Worker /// A mapping of an output slot to each of the input slots it should be connected to 315*89c4ff92SAndroid Build Coastguard Worker /// The outputSlot is from the layer that creates this tensor as one of its ouputs 316*89c4ff92SAndroid Build Coastguard Worker /// The inputSlots are from the layers that use this tensor as one of their inputs 317*89c4ff92SAndroid Build Coastguard Worker struct TensorSlots 318*89c4ff92SAndroid Build Coastguard Worker { 319*89c4ff92SAndroid Build Coastguard Worker armnn::IOutputSlot* outputSlot; 320*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::IInputSlot*> inputSlots; 321*89c4ff92SAndroid Build Coastguard Worker TensorSlotsarmnnTfLiteParser::TfLiteParserImpl::TensorSlots322*89c4ff92SAndroid Build Coastguard Worker TensorSlots() : outputSlot(nullptr) { } 323*89c4ff92SAndroid Build Coastguard Worker }; 324*89c4ff92SAndroid Build Coastguard Worker typedef std::vector<TensorSlots> TensorConnections; 325*89c4ff92SAndroid Build Coastguard Worker /// Connections for tensors in each subgraph 326*89c4ff92SAndroid Build Coastguard Worker /// The first index is the subgraph ID, the second index is the tensor ID 327*89c4ff92SAndroid Build Coastguard Worker std::vector<TensorConnections> m_SubgraphConnections; 328*89c4ff92SAndroid Build Coastguard Worker 329*89c4ff92SAndroid Build Coastguard Worker /// This is used in case that the model does not specify the output. 330*89c4ff92SAndroid Build Coastguard Worker /// The shape can be calculated from the options. 331*89c4ff92SAndroid Build Coastguard Worker std::vector<std::vector<unsigned int>> m_OverriddenOutputShapes; 332*89c4ff92SAndroid Build Coastguard Worker 333*89c4ff92SAndroid Build Coastguard Worker std::vector<unsigned int> m_ConstantsToDequantize; 334*89c4ff92SAndroid Build Coastguard Worker std::vector<unsigned int> m_ConstantsToBeCreated; 335*89c4ff92SAndroid Build Coastguard Worker std::map<size_t, armnn::TensorInfo> m_TensorInfos; 336*89c4ff92SAndroid Build Coastguard Worker }; 337*89c4ff92SAndroid Build Coastguard Worker 338*89c4ff92SAndroid Build Coastguard Worker } 339