xref: /aosp_15_r20/external/armnn/src/armnnDeserializer/Deserializer.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017,2019-2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <armnn/INetwork.hpp>
9 #include <armnnDeserializer/IDeserializer.hpp>
10 #include <ArmnnSchema_generated.h>
11 
12 #include <unordered_map>
13 
14 namespace armnnDeserializer
15 {
16 
17 // Shorthand aliases for const raw pointers to armnnSerializer types (and vectors of those pointers)
18 using ConstTensorRawPtr = const armnnSerializer::ConstTensor *;
19 using GraphPtr = const armnnSerializer::SerializedGraph *;
20 using TensorRawPtr = const armnnSerializer::TensorInfo *;
21 using Pooling2dDescriptor = const armnnSerializer::Pooling2dDescriptor *; // pointer alias; distinct from armnn::Pooling2dDescriptor
22 using Pooling3dDescriptor = const armnnSerializer::Pooling3dDescriptor *; // pointer alias; distinct from armnn::Pooling3dDescriptor
23 using NormalizationDescriptorPtr = const armnnSerializer::NormalizationDescriptor *;
24 using LstmDescriptorPtr = const armnnSerializer::LstmDescriptor *;
25 using LstmInputParamsPtr = const armnnSerializer::LstmInputParams *;
26 using QLstmDescriptorPtr = const armnnSerializer::QLstmDescriptor *;
27 using QunatizedLstmInputParamsPtr = const armnnSerializer::QuantizedLstmInputParams *; // NOTE(review): "Qunatized" is a misspelling of "Quantized"; name left unchanged
28 using TensorRawPtrVector = std::vector<TensorRawPtr>;
29 using LayerRawPtr = const armnnSerializer::LayerBase *; // same underlying type as LayerBaseRawPtr
30 using LayerBaseRawPtr = const armnnSerializer::LayerBase *;
31 using LayerBaseRawPtrVector = std::vector<LayerBaseRawPtr>;
32 using UnidirectionalSequenceLstmDescriptorPtr = const armnnSerializer::UnidirectionalSequenceLstmDescriptor *;
33 
34 class IDeserializer::DeserializerImpl
35 {
36 public:
37 
38     /// Create a network from serialized binary content supplied as a byte vector
39     armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent);
40 
41     /// Create a network from serialized binary content read from an input stream
42     armnn::INetworkPtr CreateNetworkFromBinary(std::istream& binaryContent);
43 
44     /// Retrieve binding info (layer id and tensor info) for a network input, given a layer id and the input layer name
45     BindingPointInfo GetNetworkInputBindingInfo(unsigned int layerId, const std::string& name) const;
46 
47     /// Retrieve binding info (layer id and tensor info) for a network output, given a layer id and the output layer name
48     BindingPointInfo GetNetworkOutputBindingInfo(unsigned int layerId, const std::string& name) const;
49 
50     DeserializerImpl();
51     ~DeserializerImpl() = default;
52 
53     // Copy construction and copy assignment are deleted until they are wanted and properly implemented
54     DeserializerImpl(const DeserializerImpl&) = delete;
55     DeserializerImpl& operator=(const DeserializerImpl&) = delete;
56 
57     // Static helpers, kept public so they can be called directly (e.g. from tests) without a DeserializerImpl instance
58     static GraphPtr LoadGraphFromBinary(const uint8_t* binaryContent, size_t len);
59     static TensorRawPtrVector GetInputs(const GraphPtr& graph, unsigned int layerIndex);
60     static TensorRawPtrVector GetOutputs(const GraphPtr& graph, unsigned int layerIndex);
61     static LayerBaseRawPtr GetBaseLayer(const GraphPtr& graphPtr, unsigned int layerIndex);
62     static int32_t GetBindingLayerInfo(const GraphPtr& graphPtr, unsigned int layerIndex);
63     static std::string GetLayerName(const GraphPtr& graph, unsigned int index);
64     static armnn::Pooling2dDescriptor GetPooling2dDescriptor(Pooling2dDescriptor pooling2dDescriptor,
65                                                            unsigned int layerIndex);
66     static armnn::Pooling3dDescriptor GetPooling3dDescriptor(Pooling3dDescriptor pooling3dDescriptor,
67                                                            unsigned int layerIndex);
68     static armnn::NormalizationDescriptor GetNormalizationDescriptor(
69         NormalizationDescriptorPtr normalizationDescriptor, unsigned int layerIndex);
70     static armnn::LstmDescriptor GetLstmDescriptor(LstmDescriptorPtr lstmDescriptor);
71     static armnn::LstmInputParams GetLstmInputParams(LstmDescriptorPtr lstmDescriptor,
72                                                      LstmInputParamsPtr lstmInputParams);
73     static armnn::QLstmDescriptor GetQLstmDescriptor(QLstmDescriptorPtr qLstmDescriptorPtr);
74     static armnn::UnidirectionalSequenceLstmDescriptor GetUnidirectionalSequenceLstmDescriptor(
75         UnidirectionalSequenceLstmDescriptorPtr descriptor);
76     static armnn::TensorInfo OutputShapeOfReshape(const armnn::TensorInfo & inputTensorInfo,
77                                                   const std::vector<uint32_t> & targetDimsIn);
78 
79 private:
80     /// Build the network from a flatbuffers graph that has already been loaded
81     armnn::INetworkPtr CreateNetworkFromGraph(GraphPtr graph);
82 
83     // Pointer-to-member signature shared by the Parse* functions below; values of this type are held in m_ParserFunctions
84     using LayerParsingFunction = void(DeserializerImpl::*)(GraphPtr graph, unsigned int layerIndex);
85 
86     void ParseUnsupportedLayer(GraphPtr graph, unsigned int layerIndex);
87     void ParseAbs(GraphPtr graph, unsigned int layerIndex);
88     void ParseActivation(GraphPtr graph, unsigned int layerIndex);
89     void ParseAdd(GraphPtr graph, unsigned int layerIndex);
90     void ParseArgMinMax(GraphPtr graph, unsigned int layerIndex);
91     void ParseBatchMatMul(GraphPtr graph, unsigned int layerIndex);
92     void ParseBatchToSpaceNd(GraphPtr graph, unsigned int layerIndex);
93     void ParseBatchNormalization(GraphPtr graph, unsigned int layerIndex);
94     void ParseCast(GraphPtr graph, unsigned int layerIndex);
95     void ParseChannelShuffle(GraphPtr graph, unsigned int layerIndex);
96     void ParseComparison(GraphPtr graph, unsigned int layerIndex);
97     void ParseConcat(GraphPtr graph, unsigned int layerIndex);
98     void ParseConstant(GraphPtr graph, unsigned int layerIndex);
99     void ParseConvolution2d(GraphPtr graph, unsigned int layerIndex);
100     void ParseConvolution3d(GraphPtr graph, unsigned int layerIndex);
101     void ParseDepthToSpace(GraphPtr graph, unsigned int layerIndex);
102     void ParseDepthwiseConvolution2d(GraphPtr graph, unsigned int layerIndex);
103     void ParseDequantize(GraphPtr graph, unsigned int layerIndex);
104     void ParseDetectionPostProcess(GraphPtr graph, unsigned int layerIndex);
105     void ParseDivision(GraphPtr graph, unsigned int layerIndex);
106     void ParseElementwiseBinary(GraphPtr graph, unsigned int layerIndex);
107     void ParseElementwiseUnary(GraphPtr graph, unsigned int layerIndex);
108     void ParseEqual(GraphPtr graph, unsigned int layerIndex);
109     void ParseFill(GraphPtr graph, unsigned int layerIndex);
110     void ParseFloor(GraphPtr graph, unsigned int layerIndex);
111     void ParseFullyConnected(GraphPtr graph, unsigned int layerIndex);
112     void ParseGather(GraphPtr graph, unsigned int layerIndex);
113     void ParseGatherNd(GraphPtr graph, unsigned int layerIndex);
114     void ParseGreater(GraphPtr graph, unsigned int layerIndex);
115     void ParseInstanceNormalization(GraphPtr graph, unsigned int layerIndex);
116     void ParseL2Normalization(GraphPtr graph, unsigned int layerIndex);
117     void ParseLogicalBinary(GraphPtr graph, unsigned int layerIndex);
118     void ParseLogSoftmax(GraphPtr graph, unsigned int layerIndex);
119     void ParseMaximum(GraphPtr graph, unsigned int layerIndex);
120     void ParseMean(GraphPtr graph, unsigned int layerIndex);
121     void ParseMinimum(GraphPtr graph, unsigned int layerIndex);
122     void ParseMerge(GraphPtr graph, unsigned int layerIndex);
123     void ParseMultiplication(GraphPtr graph, unsigned int layerIndex);
124     void ParseNormalization(GraphPtr graph, unsigned int layerIndex);
125     void ParseLstm(GraphPtr graph, unsigned int layerIndex);
126     void ParseQuantizedLstm(GraphPtr graph, unsigned int layerIndex);
127     void ParsePad(GraphPtr graph, unsigned int layerIndex);
128     void ParsePermute(GraphPtr graph, unsigned int layerIndex);
129     void ParsePooling2d(GraphPtr graph, unsigned int layerIndex);
130     void ParsePooling3d(GraphPtr graph, unsigned int layerIndex);
131     void ParsePrelu(GraphPtr graph, unsigned int layerIndex);
132     void ParseQLstm(GraphPtr graph, unsigned int layerIndex);
133     void ParseQuantize(GraphPtr graph, unsigned int layerIndex);
134     void ParseRank(GraphPtr graph, unsigned int layerIndex);
135     void ParseReduce(GraphPtr graph, unsigned int layerIndex);
136     void ParseReshape(GraphPtr graph, unsigned int layerIndex);
137     void ParseResize(GraphPtr graph, unsigned int layerIndex);
138     void ParseResizeBilinear(GraphPtr graph, unsigned int layerIndex);
139     void ParseRsqrt(GraphPtr graph, unsigned int layerIndex);
140     void ParseShape(GraphPtr graph, unsigned int layerIndex);
141     void ParseSlice(GraphPtr graph, unsigned int layerIndex);
142     void ParseSoftmax(GraphPtr graph, unsigned int layerIndex);
143     void ParseSpaceToBatchNd(GraphPtr graph, unsigned int layerIndex);
144     void ParseSpaceToDepth(GraphPtr graph, unsigned int layerIndex);
145     void ParseSplitter(GraphPtr graph, unsigned int layerIndex);
146     void ParseStack(GraphPtr graph, unsigned int layerIndex);
147     void ParseStandIn(GraphPtr graph, unsigned int layerIndex);
148     void ParseStridedSlice(GraphPtr graph, unsigned int layerIndex);
149     void ParseSubtraction(GraphPtr graph, unsigned int layerIndex);
150     void ParseSwitch(GraphPtr graph, unsigned int layerIndex);
151     void ParseTranspose(GraphPtr graph, unsigned int layerIndex);
152     void ParseTransposeConvolution2d(GraphPtr graph, unsigned int layerIndex);
153     void ParseUnidirectionalSequenceLstm(GraphPtr graph, unsigned int layerIndex);
154 
155     void RegisterInputSlots(GraphPtr graph,
156                             uint32_t layerIndex,
157                             armnn::IConnectableLayer* layer,
158                             std::vector<unsigned int> ignoreSlots = {});
159     void RegisterOutputSlots(GraphPtr graph,
160                              uint32_t layerIndex,
161                              armnn::IConnectableLayer* layer);
162 
163     // NOTE: the layer index passed here (sourceLayerIndex) must be taken from the flatbuffer object's index property
164     void RegisterOutputSlotOfConnection(uint32_t sourceLayerIndex, uint32_t outputSlotIndex, armnn::IOutputSlot* slot);
165     void RegisterInputSlotOfConnection(uint32_t sourceLayerIndex, uint32_t outputSlotIndex, armnn::IInputSlot* slot);
166 
167     void ResetParser();
168 
169     void SetupInputLayers(GraphPtr graphPtr);
170     void SetupOutputLayers(GraphPtr graphPtr);
171 
172     /// Helpers to get the index of an input/output layer in the flatbuffer vector from its bindingId property
173     unsigned int GetInputLayerInVector(GraphPtr graph, int targetId);
174     unsigned int GetOutputLayerInVector(GraphPtr graph, int targetId);
175 
176     /// Helper to get the index of a layer in the flatbuffer vector from its index property
177     unsigned int GetLayerIndexInVector(GraphPtr graph, unsigned int index);
178 
179     struct FeatureVersions
180     {
181         // Binding id scheme version; defaults to zero for backward compatibility
182         unsigned int m_BindingIdScheme = 0;
183 
184         // Weights layout scheme version; defaults to zero for backward compatibility
185         unsigned int m_WeightsLayoutScheme = 0;
186 
187         // Const-tensors-as-inputs version; defaults to zero for backward compatibility
188         unsigned int m_ConstTensorsAsInputs = 0;
189     };
190 
191     FeatureVersions GetFeatureVersions(GraphPtr graph);
192 
193     /// The network being built. Gets cleared after it is passed to the user
194     armnn::INetworkPtr                    m_Network;
195     std::vector<LayerParsingFunction>     m_ParserFunctions; // NOTE(review): presumably indexed by serialized layer type - confirm in the .cpp
196 
197     using NameToBindingInfo = std::pair<std::string, BindingPointInfo >; // (name, binding info) pair
198     std::vector<NameToBindingInfo>    m_InputBindings;
199     std::vector<NameToBindingInfo>    m_OutputBindings;
200 
201     /// Describes the connections of a single layer
202     struct Connections
203     {
204         // Maps output slot index (property in flatbuffer object) to the IOutputSlot pointer for that slot
205         std::unordered_map<unsigned int, armnn::IOutputSlot*> outputSlots;
206 
207         // Maps output slot index to the IInputSlot pointers that the output slot should be connected to
208         std::unordered_map<unsigned int, std::vector<armnn::IInputSlot*>> inputSlots;
209     };
210 
211     /// Maps layer index (index property in flatbuffer object) to the Connections of that layer
212     std::unordered_map<unsigned int, Connections> m_GraphConnections;
213 };
214 
215 } // namespace armnnDeserializer