1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017,2019-2023 Arm Ltd and Contributors. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker 6*89c4ff92SAndroid Build Coastguard Worker #pragma once 7*89c4ff92SAndroid Build Coastguard Worker 8*89c4ff92SAndroid Build Coastguard Worker #include <armnn/INetwork.hpp> 9*89c4ff92SAndroid Build Coastguard Worker #include <armnnDeserializer/IDeserializer.hpp> 10*89c4ff92SAndroid Build Coastguard Worker #include <ArmnnSchema_generated.h> 11*89c4ff92SAndroid Build Coastguard Worker 12*89c4ff92SAndroid Build Coastguard Worker #include <unordered_map> 13*89c4ff92SAndroid Build Coastguard Worker 14*89c4ff92SAndroid Build Coastguard Worker namespace armnnDeserializer 15*89c4ff92SAndroid Build Coastguard Worker { 16*89c4ff92SAndroid Build Coastguard Worker 17*89c4ff92SAndroid Build Coastguard Worker // Shorthands for deserializer types 18*89c4ff92SAndroid Build Coastguard Worker using ConstTensorRawPtr = const armnnSerializer::ConstTensor *; 19*89c4ff92SAndroid Build Coastguard Worker using GraphPtr = const armnnSerializer::SerializedGraph *; 20*89c4ff92SAndroid Build Coastguard Worker using TensorRawPtr = const armnnSerializer::TensorInfo *; 21*89c4ff92SAndroid Build Coastguard Worker using Pooling2dDescriptor = const armnnSerializer::Pooling2dDescriptor *; 22*89c4ff92SAndroid Build Coastguard Worker using Pooling3dDescriptor = const armnnSerializer::Pooling3dDescriptor *; 23*89c4ff92SAndroid Build Coastguard Worker using NormalizationDescriptorPtr = const armnnSerializer::NormalizationDescriptor *; 24*89c4ff92SAndroid Build Coastguard Worker using LstmDescriptorPtr = const armnnSerializer::LstmDescriptor *; 25*89c4ff92SAndroid Build Coastguard Worker using LstmInputParamsPtr = const armnnSerializer::LstmInputParams *; 26*89c4ff92SAndroid Build Coastguard Worker using QLstmDescriptorPtr = const armnnSerializer::QLstmDescriptor *; 27*89c4ff92SAndroid Build Coastguard Worker using QunatizedLstmInputParamsPtr = const armnnSerializer::QuantizedLstmInputParams *; 28*89c4ff92SAndroid Build Coastguard Worker using TensorRawPtrVector = std::vector<TensorRawPtr>; 29*89c4ff92SAndroid Build Coastguard Worker using LayerRawPtr = const armnnSerializer::LayerBase *; 30*89c4ff92SAndroid Build Coastguard Worker using LayerBaseRawPtr = const armnnSerializer::LayerBase *; 31*89c4ff92SAndroid Build Coastguard Worker using LayerBaseRawPtrVector = std::vector<LayerBaseRawPtr>; 32*89c4ff92SAndroid Build Coastguard Worker using UnidirectionalSequenceLstmDescriptorPtr = const armnnSerializer::UnidirectionalSequenceLstmDescriptor *; 33*89c4ff92SAndroid Build Coastguard Worker 34*89c4ff92SAndroid Build Coastguard Worker class IDeserializer::DeserializerImpl 35*89c4ff92SAndroid Build Coastguard Worker { 36*89c4ff92SAndroid Build Coastguard Worker public: 37*89c4ff92SAndroid Build Coastguard Worker 38*89c4ff92SAndroid Build Coastguard Worker /// Create an input network from binary file contents 39*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent); 40*89c4ff92SAndroid Build Coastguard Worker 41*89c4ff92SAndroid Build Coastguard Worker /// Create an input network from a binary input stream 42*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr CreateNetworkFromBinary(std::istream& binaryContent); 43*89c4ff92SAndroid Build Coastguard Worker 44*89c4ff92SAndroid Build Coastguard Worker /// Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name 45*89c4ff92SAndroid Build Coastguard Worker BindingPointInfo GetNetworkInputBindingInfo(unsigned int layerId, const std::string& name) const; 46*89c4ff92SAndroid Build Coastguard Worker 47*89c4ff92SAndroid Build Coastguard Worker /// Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name 48*89c4ff92SAndroid Build Coastguard Worker BindingPointInfo GetNetworkOutputBindingInfo(unsigned int layerId, const std::string& name) const; 49*89c4ff92SAndroid Build Coastguard Worker 50*89c4ff92SAndroid Build Coastguard Worker DeserializerImpl(); 51*89c4ff92SAndroid Build Coastguard Worker ~DeserializerImpl() = default; 52*89c4ff92SAndroid Build Coastguard Worker 53*89c4ff92SAndroid Build Coastguard Worker // No copying allowed until it is wanted and properly implemented 54*89c4ff92SAndroid Build Coastguard Worker DeserializerImpl(const DeserializerImpl&) = delete; 55*89c4ff92SAndroid Build Coastguard Worker DeserializerImpl& operator=(const DeserializerImpl&) = delete; 56*89c4ff92SAndroid Build Coastguard Worker 57*89c4ff92SAndroid Build Coastguard Worker // testable helpers 58*89c4ff92SAndroid Build Coastguard Worker static GraphPtr LoadGraphFromBinary(const uint8_t* binaryContent, size_t len); 59*89c4ff92SAndroid Build Coastguard Worker static TensorRawPtrVector GetInputs(const GraphPtr& graph, unsigned int layerIndex); 60*89c4ff92SAndroid Build Coastguard Worker static TensorRawPtrVector GetOutputs(const GraphPtr& graph, unsigned int layerIndex); 61*89c4ff92SAndroid Build Coastguard Worker static LayerBaseRawPtr GetBaseLayer(const GraphPtr& graphPtr, unsigned int layerIndex); 62*89c4ff92SAndroid Build Coastguard Worker static int32_t GetBindingLayerInfo(const GraphPtr& graphPtr, unsigned int layerIndex); 63*89c4ff92SAndroid Build Coastguard Worker static std::string GetLayerName(const GraphPtr& graph, unsigned int index); 64*89c4ff92SAndroid Build Coastguard Worker static armnn::Pooling2dDescriptor GetPooling2dDescriptor(Pooling2dDescriptor pooling2dDescriptor, 65*89c4ff92SAndroid Build Coastguard Worker unsigned int layerIndex); 66*89c4ff92SAndroid Build Coastguard Worker static armnn::Pooling3dDescriptor GetPooling3dDescriptor(Pooling3dDescriptor pooling3dDescriptor, 67*89c4ff92SAndroid Build Coastguard Worker unsigned int layerIndex); 68*89c4ff92SAndroid Build Coastguard Worker static armnn::NormalizationDescriptor GetNormalizationDescriptor( 69*89c4ff92SAndroid Build Coastguard Worker NormalizationDescriptorPtr normalizationDescriptor, unsigned int layerIndex); 70*89c4ff92SAndroid Build Coastguard Worker static armnn::LstmDescriptor GetLstmDescriptor(LstmDescriptorPtr lstmDescriptor); 71*89c4ff92SAndroid Build Coastguard Worker static armnn::LstmInputParams GetLstmInputParams(LstmDescriptorPtr lstmDescriptor, 72*89c4ff92SAndroid Build Coastguard Worker LstmInputParamsPtr lstmInputParams); 73*89c4ff92SAndroid Build Coastguard Worker static armnn::QLstmDescriptor GetQLstmDescriptor(QLstmDescriptorPtr qLstmDescriptorPtr); 74*89c4ff92SAndroid Build Coastguard Worker static armnn::UnidirectionalSequenceLstmDescriptor GetUnidirectionalSequenceLstmDescriptor( 75*89c4ff92SAndroid Build Coastguard Worker UnidirectionalSequenceLstmDescriptorPtr descriptor); 76*89c4ff92SAndroid Build Coastguard Worker static armnn::TensorInfo OutputShapeOfReshape(const armnn::TensorInfo & inputTensorInfo, 77*89c4ff92SAndroid Build Coastguard Worker const std::vector<uint32_t> & targetDimsIn); 78*89c4ff92SAndroid Build Coastguard Worker 79*89c4ff92SAndroid Build Coastguard Worker private: 80*89c4ff92SAndroid Build Coastguard Worker /// Create the network from an already loaded flatbuffers graph 81*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr CreateNetworkFromGraph(GraphPtr graph); 82*89c4ff92SAndroid Build Coastguard Worker 83*89c4ff92SAndroid Build Coastguard Worker // signature for the parser functions 84*89c4ff92SAndroid Build Coastguard Worker using LayerParsingFunction = void(DeserializerImpl::*)(GraphPtr graph, unsigned int layerIndex); 85*89c4ff92SAndroid Build Coastguard Worker 86*89c4ff92SAndroid Build Coastguard Worker void ParseUnsupportedLayer(GraphPtr graph, unsigned int layerIndex); 87*89c4ff92SAndroid Build Coastguard Worker void ParseAbs(GraphPtr graph, unsigned int layerIndex); 88*89c4ff92SAndroid Build Coastguard Worker void ParseActivation(GraphPtr graph, unsigned int layerIndex); 89*89c4ff92SAndroid Build Coastguard Worker void ParseAdd(GraphPtr graph, unsigned int layerIndex); 90*89c4ff92SAndroid Build Coastguard Worker void ParseArgMinMax(GraphPtr graph, unsigned int layerIndex); 91*89c4ff92SAndroid Build Coastguard Worker void ParseBatchMatMul(GraphPtr graph, unsigned int layerIndex); 92*89c4ff92SAndroid Build Coastguard Worker void ParseBatchToSpaceNd(GraphPtr graph, unsigned int layerIndex); 93*89c4ff92SAndroid Build Coastguard Worker void ParseBatchNormalization(GraphPtr graph, unsigned int layerIndex); 94*89c4ff92SAndroid Build Coastguard Worker void ParseCast(GraphPtr graph, unsigned int layerIndex); 95*89c4ff92SAndroid Build Coastguard Worker void ParseChannelShuffle(GraphPtr graph, unsigned int layerIndex); 96*89c4ff92SAndroid Build Coastguard Worker void ParseComparison(GraphPtr graph, unsigned int layerIndex); 97*89c4ff92SAndroid Build Coastguard Worker void ParseConcat(GraphPtr graph, unsigned int layerIndex); 98*89c4ff92SAndroid Build Coastguard Worker void ParseConstant(GraphPtr graph, unsigned int layerIndex); 99*89c4ff92SAndroid Build Coastguard Worker void ParseConvolution2d(GraphPtr graph, unsigned int layerIndex); 100*89c4ff92SAndroid Build Coastguard Worker void ParseConvolution3d(GraphPtr graph, unsigned int layerIndex); 101*89c4ff92SAndroid Build Coastguard Worker void ParseDepthToSpace(GraphPtr graph, unsigned int layerIndex); 102*89c4ff92SAndroid Build Coastguard Worker void ParseDepthwiseConvolution2d(GraphPtr graph, unsigned int layerIndex); 103*89c4ff92SAndroid Build Coastguard Worker void ParseDequantize(GraphPtr graph, unsigned int layerIndex); 104*89c4ff92SAndroid Build Coastguard Worker void ParseDetectionPostProcess(GraphPtr graph, unsigned int layerIndex); 105*89c4ff92SAndroid Build Coastguard Worker void ParseDivision(GraphPtr graph, unsigned int layerIndex); 106*89c4ff92SAndroid Build Coastguard Worker void ParseElementwiseBinary(GraphPtr graph, unsigned int layerIndex); 107*89c4ff92SAndroid Build Coastguard Worker void ParseElementwiseUnary(GraphPtr graph, unsigned int layerIndex); 108*89c4ff92SAndroid Build Coastguard Worker void ParseEqual(GraphPtr graph, unsigned int layerIndex); 109*89c4ff92SAndroid Build Coastguard Worker void ParseFill(GraphPtr graph, unsigned int layerIndex); 110*89c4ff92SAndroid Build Coastguard Worker void ParseFloor(GraphPtr graph, unsigned int layerIndex); 111*89c4ff92SAndroid Build Coastguard Worker void ParseFullyConnected(GraphPtr graph, unsigned int layerIndex); 112*89c4ff92SAndroid Build Coastguard Worker void ParseGather(GraphPtr graph, unsigned int layerIndex); 113*89c4ff92SAndroid Build Coastguard Worker void ParseGatherNd(GraphPtr graph, unsigned int layerIndex); 114*89c4ff92SAndroid Build Coastguard Worker void ParseGreater(GraphPtr graph, unsigned int layerIndex); 115*89c4ff92SAndroid Build Coastguard Worker void ParseInstanceNormalization(GraphPtr graph, unsigned int layerIndex); 116*89c4ff92SAndroid Build Coastguard Worker void ParseL2Normalization(GraphPtr graph, unsigned int layerIndex); 117*89c4ff92SAndroid Build Coastguard Worker void ParseLogicalBinary(GraphPtr graph, unsigned int layerIndex); 118*89c4ff92SAndroid Build Coastguard Worker void ParseLogSoftmax(GraphPtr graph, unsigned int layerIndex); 119*89c4ff92SAndroid Build Coastguard Worker void ParseMaximum(GraphPtr graph, unsigned int layerIndex); 120*89c4ff92SAndroid Build Coastguard Worker void ParseMean(GraphPtr graph, unsigned int layerIndex); 121*89c4ff92SAndroid Build Coastguard Worker void ParseMinimum(GraphPtr graph, unsigned int layerIndex); 122*89c4ff92SAndroid Build Coastguard Worker void ParseMerge(GraphPtr graph, unsigned int layerIndex); 123*89c4ff92SAndroid Build Coastguard Worker void ParseMultiplication(GraphPtr graph, unsigned int layerIndex); 124*89c4ff92SAndroid Build Coastguard Worker void ParseNormalization(GraphPtr graph, unsigned int layerIndex); 125*89c4ff92SAndroid Build Coastguard Worker void ParseLstm(GraphPtr graph, unsigned int layerIndex); 126*89c4ff92SAndroid Build Coastguard Worker void ParseQuantizedLstm(GraphPtr graph, unsigned int layerIndex); 127*89c4ff92SAndroid Build Coastguard Worker void ParsePad(GraphPtr graph, unsigned int layerIndex); 128*89c4ff92SAndroid Build Coastguard Worker void ParsePermute(GraphPtr graph, unsigned int layerIndex); 129*89c4ff92SAndroid Build Coastguard Worker void ParsePooling2d(GraphPtr graph, unsigned int layerIndex); 130*89c4ff92SAndroid Build Coastguard Worker void ParsePooling3d(GraphPtr graph, unsigned int layerIndex); 131*89c4ff92SAndroid Build Coastguard Worker void ParsePrelu(GraphPtr graph, unsigned int layerIndex); 132*89c4ff92SAndroid Build Coastguard Worker void ParseQLstm(GraphPtr graph, unsigned int layerIndex); 133*89c4ff92SAndroid Build Coastguard Worker void ParseQuantize(GraphPtr graph, unsigned int layerIndex); 134*89c4ff92SAndroid Build Coastguard Worker void ParseRank(GraphPtr graph, unsigned int layerIndex); 135*89c4ff92SAndroid Build Coastguard Worker void ParseReduce(GraphPtr graph, unsigned int layerIndex); 136*89c4ff92SAndroid Build Coastguard Worker void ParseReshape(GraphPtr graph, unsigned int layerIndex); 137*89c4ff92SAndroid Build Coastguard Worker void ParseResize(GraphPtr graph, unsigned int layerIndex); 138*89c4ff92SAndroid Build Coastguard Worker void ParseResizeBilinear(GraphPtr graph, unsigned int layerIndex); 139*89c4ff92SAndroid Build Coastguard Worker void ParseRsqrt(GraphPtr graph, unsigned int layerIndex); 140*89c4ff92SAndroid Build Coastguard Worker void ParseShape(GraphPtr graph, unsigned int layerIndex); 141*89c4ff92SAndroid Build Coastguard Worker void ParseSlice(GraphPtr graph, unsigned int layerIndex); 142*89c4ff92SAndroid Build Coastguard Worker void ParseSoftmax(GraphPtr graph, unsigned int layerIndex); 143*89c4ff92SAndroid Build Coastguard Worker void ParseSpaceToBatchNd(GraphPtr graph, unsigned int layerIndex); 144*89c4ff92SAndroid Build Coastguard Worker void ParseSpaceToDepth(GraphPtr graph, unsigned int layerIndex); 145*89c4ff92SAndroid Build Coastguard Worker void ParseSplitter(GraphPtr graph, unsigned int layerIndex); 146*89c4ff92SAndroid Build Coastguard Worker void ParseStack(GraphPtr graph, unsigned int layerIndex); 147*89c4ff92SAndroid Build Coastguard Worker void ParseStandIn(GraphPtr graph, unsigned int layerIndex); 148*89c4ff92SAndroid Build Coastguard Worker void ParseStridedSlice(GraphPtr graph, unsigned int layerIndex); 149*89c4ff92SAndroid Build Coastguard Worker void ParseSubtraction(GraphPtr graph, unsigned int layerIndex); 150*89c4ff92SAndroid Build Coastguard Worker void ParseSwitch(GraphPtr graph, unsigned int layerIndex); 151*89c4ff92SAndroid Build Coastguard Worker void ParseTranspose(GraphPtr graph, unsigned int layerIndex); 152*89c4ff92SAndroid Build Coastguard Worker void ParseTransposeConvolution2d(GraphPtr graph, unsigned int layerIndex); 153*89c4ff92SAndroid Build Coastguard Worker void ParseUnidirectionalSequenceLstm(GraphPtr graph, unsigned int layerIndex); 154*89c4ff92SAndroid Build Coastguard Worker 155*89c4ff92SAndroid Build Coastguard Worker void RegisterInputSlots(GraphPtr graph, 156*89c4ff92SAndroid Build Coastguard Worker uint32_t layerIndex, 157*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* layer, 158*89c4ff92SAndroid Build Coastguard Worker std::vector<unsigned int> ignoreSlots = {}); 159*89c4ff92SAndroid Build Coastguard Worker void RegisterOutputSlots(GraphPtr graph, 160*89c4ff92SAndroid Build Coastguard Worker uint32_t layerIndex, 161*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* layer); 162*89c4ff92SAndroid Build Coastguard Worker 163*89c4ff92SAndroid Build Coastguard Worker // NOTE index here must be from flatbuffer object index property 164*89c4ff92SAndroid Build Coastguard Worker void RegisterOutputSlotOfConnection(uint32_t sourceLayerIndex, uint32_t outputSlotIndex, armnn::IOutputSlot* slot); 165*89c4ff92SAndroid Build Coastguard Worker void RegisterInputSlotOfConnection(uint32_t sourceLayerIndex, uint32_t outputSlotIndex, armnn::IInputSlot* slot); 166*89c4ff92SAndroid Build Coastguard Worker 167*89c4ff92SAndroid Build Coastguard Worker void ResetParser(); 168*89c4ff92SAndroid Build Coastguard Worker 169*89c4ff92SAndroid Build Coastguard Worker void SetupInputLayers(GraphPtr graphPtr); 170*89c4ff92SAndroid Build Coastguard Worker void SetupOutputLayers(GraphPtr graphPtr); 171*89c4ff92SAndroid Build Coastguard Worker 172*89c4ff92SAndroid Build Coastguard Worker /// Helper to get the index of the layer in the flatbuffer vector from its bindingId property 173*89c4ff92SAndroid Build Coastguard Worker unsigned int GetInputLayerInVector(GraphPtr graph, int targetId); 174*89c4ff92SAndroid Build Coastguard Worker unsigned int GetOutputLayerInVector(GraphPtr graph, int targetId); 175*89c4ff92SAndroid Build Coastguard Worker 176*89c4ff92SAndroid Build Coastguard Worker /// Helper to get the index of the layer in the flatbuffer vector from its index property 177*89c4ff92SAndroid Build Coastguard Worker unsigned int GetLayerIndexInVector(GraphPtr graph, unsigned int index); 178*89c4ff92SAndroid Build Coastguard Worker 179*89c4ff92SAndroid Build Coastguard Worker struct FeatureVersions 180*89c4ff92SAndroid Build Coastguard Worker { 181*89c4ff92SAndroid Build Coastguard Worker // Default values to zero for backward compatibility 182*89c4ff92SAndroid Build Coastguard Worker unsigned int m_BindingIdScheme = 0; 183*89c4ff92SAndroid Build Coastguard Worker 184*89c4ff92SAndroid Build Coastguard Worker // Default values to zero for backward compatibility 185*89c4ff92SAndroid Build Coastguard Worker unsigned int m_WeightsLayoutScheme = 0; 186*89c4ff92SAndroid Build Coastguard Worker 187*89c4ff92SAndroid Build Coastguard Worker // Default values to zero for backward compatibility 188*89c4ff92SAndroid Build Coastguard Worker unsigned int m_ConstTensorsAsInputs = 0; 189*89c4ff92SAndroid Build Coastguard Worker }; 190*89c4ff92SAndroid Build Coastguard Worker 191*89c4ff92SAndroid Build Coastguard Worker FeatureVersions GetFeatureVersions(GraphPtr graph); 192*89c4ff92SAndroid Build Coastguard Worker 193*89c4ff92SAndroid Build Coastguard Worker /// The network we're building. Gets cleared after it is passed to the user 194*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr m_Network; 195*89c4ff92SAndroid Build Coastguard Worker std::vector<LayerParsingFunction> m_ParserFunctions; 196*89c4ff92SAndroid Build Coastguard Worker 197*89c4ff92SAndroid Build Coastguard Worker using NameToBindingInfo = std::pair<std::string, BindingPointInfo >; 198*89c4ff92SAndroid Build Coastguard Worker std::vector<NameToBindingInfo> m_InputBindings; 199*89c4ff92SAndroid Build Coastguard Worker std::vector<NameToBindingInfo> m_OutputBindings; 200*89c4ff92SAndroid Build Coastguard Worker 201*89c4ff92SAndroid Build Coastguard Worker /// This struct describe connections for each layer 202*89c4ff92SAndroid Build Coastguard Worker struct Connections 203*89c4ff92SAndroid Build Coastguard Worker { 204*89c4ff92SAndroid Build Coastguard Worker // Maps output slot index (property in flatbuffer object) to IOutputSlot pointer 205*89c4ff92SAndroid Build Coastguard Worker std::unordered_map<unsigned int, armnn::IOutputSlot*> outputSlots; 206*89c4ff92SAndroid Build Coastguard Worker 207*89c4ff92SAndroid Build Coastguard Worker // Maps output slot index to IInputSlot pointer the output slot should be connected to 208*89c4ff92SAndroid Build Coastguard Worker std::unordered_map<unsigned int, std::vector<armnn::IInputSlot*>> inputSlots; 209*89c4ff92SAndroid Build Coastguard Worker }; 210*89c4ff92SAndroid Build Coastguard Worker 211*89c4ff92SAndroid Build Coastguard Worker /// Maps layer index (index property in flatbuffer object) to Connections for each layer 212*89c4ff92SAndroid Build Coastguard Worker std::unordered_map<unsigned int, Connections> m_GraphConnections; 213*89c4ff92SAndroid Build Coastguard Worker }; 214*89c4ff92SAndroid Build Coastguard Worker 215*89c4ff92SAndroid Build Coastguard Worker } // namespace armnnDeserializer