1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017,2019-2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #include "Deserializer.hpp"
7*89c4ff92SAndroid Build Coastguard Worker
8*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Descriptors.hpp>
9*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Exceptions.hpp>
10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/TypesUtils.hpp>
11*89c4ff92SAndroid Build Coastguard Worker #include <armnn/LstmParams.hpp>
12*89c4ff92SAndroid Build Coastguard Worker #include <armnn/QuantizedLstmParams.hpp>
13*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Logging.hpp>
14*89c4ff92SAndroid Build Coastguard Worker
15*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/Permute.hpp>
16*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/Transpose.hpp>
17*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/Assert.hpp>
18*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/IgnoreUnused.hpp>
19*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/NumericCast.hpp>
20*89c4ff92SAndroid Build Coastguard Worker
21*89c4ff92SAndroid Build Coastguard Worker #include <ParserHelper.hpp>
22*89c4ff92SAndroid Build Coastguard Worker #include <VerificationHelpers.hpp>
23*89c4ff92SAndroid Build Coastguard Worker
24*89c4ff92SAndroid Build Coastguard Worker #include <fmt/format.h>
25*89c4ff92SAndroid Build Coastguard Worker
26*89c4ff92SAndroid Build Coastguard Worker #include <fstream>
27*89c4ff92SAndroid Build Coastguard Worker #include <algorithm>
28*89c4ff92SAndroid Build Coastguard Worker #include <limits>
29*89c4ff92SAndroid Build Coastguard Worker #include <numeric>
30*89c4ff92SAndroid Build Coastguard Worker
31*89c4ff92SAndroid Build Coastguard Worker using armnn::ParseException;
32*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
33*89c4ff92SAndroid Build Coastguard Worker using namespace armnnSerializer;
34*89c4ff92SAndroid Build Coastguard Worker
35*89c4ff92SAndroid Build Coastguard Worker namespace armnnDeserializer
36*89c4ff92SAndroid Build Coastguard Worker {
37*89c4ff92SAndroid Build Coastguard Worker
IDeserializer()38*89c4ff92SAndroid Build Coastguard Worker IDeserializer::IDeserializer() : pDeserializerImpl(new DeserializerImpl()){}
39*89c4ff92SAndroid Build Coastguard Worker
40*89c4ff92SAndroid Build Coastguard Worker IDeserializer::~IDeserializer() = default;
41*89c4ff92SAndroid Build Coastguard Worker
CreateRaw()42*89c4ff92SAndroid Build Coastguard Worker IDeserializer *IDeserializer::CreateRaw()
43*89c4ff92SAndroid Build Coastguard Worker {
44*89c4ff92SAndroid Build Coastguard Worker return new IDeserializer();
45*89c4ff92SAndroid Build Coastguard Worker }
46*89c4ff92SAndroid Build Coastguard Worker
Create()47*89c4ff92SAndroid Build Coastguard Worker IDeserializerPtr IDeserializer::Create()
48*89c4ff92SAndroid Build Coastguard Worker {
49*89c4ff92SAndroid Build Coastguard Worker return IDeserializerPtr(CreateRaw(), &IDeserializer::Destroy);
50*89c4ff92SAndroid Build Coastguard Worker }
51*89c4ff92SAndroid Build Coastguard Worker
Destroy(IDeserializer * parser)52*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::Destroy(IDeserializer *parser)
53*89c4ff92SAndroid Build Coastguard Worker {
54*89c4ff92SAndroid Build Coastguard Worker delete parser;
55*89c4ff92SAndroid Build Coastguard Worker }
56*89c4ff92SAndroid Build Coastguard Worker
CreateNetworkFromBinary(const std::vector<uint8_t> & binaryContent)57*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr IDeserializer::CreateNetworkFromBinary(const std::vector<uint8_t> &binaryContent)
58*89c4ff92SAndroid Build Coastguard Worker {
59*89c4ff92SAndroid Build Coastguard Worker return pDeserializerImpl->CreateNetworkFromBinary(binaryContent);
60*89c4ff92SAndroid Build Coastguard Worker }
61*89c4ff92SAndroid Build Coastguard Worker
CreateNetworkFromBinary(std::istream & binaryContent)62*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr IDeserializer::CreateNetworkFromBinary(std::istream &binaryContent)
63*89c4ff92SAndroid Build Coastguard Worker {
64*89c4ff92SAndroid Build Coastguard Worker return pDeserializerImpl->CreateNetworkFromBinary(binaryContent);
65*89c4ff92SAndroid Build Coastguard Worker }
66*89c4ff92SAndroid Build Coastguard Worker
GetNetworkInputBindingInfo(unsigned int layerId,const std::string & name) const67*89c4ff92SAndroid Build Coastguard Worker BindingPointInfo IDeserializer::GetNetworkInputBindingInfo(unsigned int layerId, const std::string &name) const
68*89c4ff92SAndroid Build Coastguard Worker {
69*89c4ff92SAndroid Build Coastguard Worker return pDeserializerImpl->GetNetworkInputBindingInfo(layerId, name);
70*89c4ff92SAndroid Build Coastguard Worker }
71*89c4ff92SAndroid Build Coastguard Worker
GetNetworkOutputBindingInfo(unsigned int layerId,const std::string & name) const72*89c4ff92SAndroid Build Coastguard Worker BindingPointInfo IDeserializer::GetNetworkOutputBindingInfo(unsigned int layerId, const std::string &name) const
73*89c4ff92SAndroid Build Coastguard Worker {
74*89c4ff92SAndroid Build Coastguard Worker return pDeserializerImpl->GetNetworkOutputBindingInfo(layerId, name);
75*89c4ff92SAndroid Build Coastguard Worker }
76*89c4ff92SAndroid Build Coastguard Worker
77*89c4ff92SAndroid Build Coastguard Worker namespace
78*89c4ff92SAndroid Build Coastguard Worker {
79*89c4ff92SAndroid Build Coastguard Worker
80*89c4ff92SAndroid Build Coastguard Worker const uint32_t VIRTUAL_LAYER_ID = std::numeric_limits<uint32_t>::max();
81*89c4ff92SAndroid Build Coastguard Worker
CheckGraph(const GraphPtr & graph,unsigned int layersIndex,const CheckLocation & location)82*89c4ff92SAndroid Build Coastguard Worker void CheckGraph(const GraphPtr& graph,
83*89c4ff92SAndroid Build Coastguard Worker unsigned int layersIndex,
84*89c4ff92SAndroid Build Coastguard Worker const CheckLocation& location)
85*89c4ff92SAndroid Build Coastguard Worker {
86*89c4ff92SAndroid Build Coastguard Worker if (graph->layers() == nullptr)
87*89c4ff92SAndroid Build Coastguard Worker {
88*89c4ff92SAndroid Build Coastguard Worker throw ParseException(fmt::format("{0} was called with invalid (null) graph. "
89*89c4ff92SAndroid Build Coastguard Worker "Possible reason is that the graph is not yet loaded and Unpack(ed). "
90*89c4ff92SAndroid Build Coastguard Worker "layers:{1} at {2}",
91*89c4ff92SAndroid Build Coastguard Worker location.m_Function,
92*89c4ff92SAndroid Build Coastguard Worker layersIndex,
93*89c4ff92SAndroid Build Coastguard Worker location.FileLine()));
94*89c4ff92SAndroid Build Coastguard Worker }
95*89c4ff92SAndroid Build Coastguard Worker else if (layersIndex >= graph->layers()->size())
96*89c4ff92SAndroid Build Coastguard Worker {
97*89c4ff92SAndroid Build Coastguard Worker throw ParseException(fmt::format("{0} was called with an invalid layers index. layers:{1} at {2}",
98*89c4ff92SAndroid Build Coastguard Worker location.m_Function,
99*89c4ff92SAndroid Build Coastguard Worker layersIndex,
100*89c4ff92SAndroid Build Coastguard Worker location.FileLine()));
101*89c4ff92SAndroid Build Coastguard Worker }
102*89c4ff92SAndroid Build Coastguard Worker }
103*89c4ff92SAndroid Build Coastguard Worker
CheckLayers(const GraphPtr & graph,unsigned int layersIndex,unsigned int layerIndex,const CheckLocation & location)104*89c4ff92SAndroid Build Coastguard Worker void CheckLayers(const GraphPtr& graph,
105*89c4ff92SAndroid Build Coastguard Worker unsigned int layersIndex,
106*89c4ff92SAndroid Build Coastguard Worker unsigned int layerIndex,
107*89c4ff92SAndroid Build Coastguard Worker const CheckLocation& location)
108*89c4ff92SAndroid Build Coastguard Worker {
109*89c4ff92SAndroid Build Coastguard Worker if (graph->layers() == nullptr)
110*89c4ff92SAndroid Build Coastguard Worker {
111*89c4ff92SAndroid Build Coastguard Worker throw ParseException(fmt::format("{0} was called with invalid (null) graph. "
112*89c4ff92SAndroid Build Coastguard Worker "Possible reason is that the graph is not yet loaded and Unpack(ed). "
113*89c4ff92SAndroid Build Coastguard Worker "layers:{1} at {2}",
114*89c4ff92SAndroid Build Coastguard Worker location.m_Function,
115*89c4ff92SAndroid Build Coastguard Worker layersIndex,
116*89c4ff92SAndroid Build Coastguard Worker location.FileLine()));
117*89c4ff92SAndroid Build Coastguard Worker }
118*89c4ff92SAndroid Build Coastguard Worker else if (layersIndex >= graph->layers()->size())
119*89c4ff92SAndroid Build Coastguard Worker {
120*89c4ff92SAndroid Build Coastguard Worker throw ParseException(fmt::format("{0} was called with an invalid layers index. "
121*89c4ff92SAndroid Build Coastguard Worker "layers:{1} at {2}",
122*89c4ff92SAndroid Build Coastguard Worker location.m_Function,
123*89c4ff92SAndroid Build Coastguard Worker layersIndex,
124*89c4ff92SAndroid Build Coastguard Worker location.FileLine()));
125*89c4ff92SAndroid Build Coastguard Worker }
126*89c4ff92SAndroid Build Coastguard Worker else if (layerIndex >= graph->layers()[layersIndex].size()
127*89c4ff92SAndroid Build Coastguard Worker && layerIndex != VIRTUAL_LAYER_ID)
128*89c4ff92SAndroid Build Coastguard Worker {
129*89c4ff92SAndroid Build Coastguard Worker throw ParseException(fmt::format("{0} was called with an invalid layer index. "
130*89c4ff92SAndroid Build Coastguard Worker "layers:{1} layer:{2} at {3}",
131*89c4ff92SAndroid Build Coastguard Worker location.m_Function,
132*89c4ff92SAndroid Build Coastguard Worker layersIndex,
133*89c4ff92SAndroid Build Coastguard Worker layerIndex,
134*89c4ff92SAndroid Build Coastguard Worker location.FileLine()));
135*89c4ff92SAndroid Build Coastguard Worker }
136*89c4ff92SAndroid Build Coastguard Worker }
137*89c4ff92SAndroid Build Coastguard Worker
CheckTensorPtr(TensorRawPtr rawPtr,const CheckLocation & location)138*89c4ff92SAndroid Build Coastguard Worker void CheckTensorPtr(TensorRawPtr rawPtr,
139*89c4ff92SAndroid Build Coastguard Worker const CheckLocation& location)
140*89c4ff92SAndroid Build Coastguard Worker {
141*89c4ff92SAndroid Build Coastguard Worker if (rawPtr == nullptr)
142*89c4ff92SAndroid Build Coastguard Worker {
143*89c4ff92SAndroid Build Coastguard Worker throw ParseException(fmt::format("{0} was called with a null tensor pointer. at {1}",
144*89c4ff92SAndroid Build Coastguard Worker location.m_Function,
145*89c4ff92SAndroid Build Coastguard Worker location.FileLine()));
146*89c4ff92SAndroid Build Coastguard Worker }
147*89c4ff92SAndroid Build Coastguard Worker }
148*89c4ff92SAndroid Build Coastguard Worker
CheckConstTensorPtr(ConstTensorRawPtr rawPtr,const CheckLocation & location)149*89c4ff92SAndroid Build Coastguard Worker void CheckConstTensorPtr(ConstTensorRawPtr rawPtr,
150*89c4ff92SAndroid Build Coastguard Worker const CheckLocation& location)
151*89c4ff92SAndroid Build Coastguard Worker {
152*89c4ff92SAndroid Build Coastguard Worker if (rawPtr == nullptr)
153*89c4ff92SAndroid Build Coastguard Worker {
154*89c4ff92SAndroid Build Coastguard Worker throw ParseException(fmt::format("{0} was called with a null const tensor pointer. at {1}",
155*89c4ff92SAndroid Build Coastguard Worker location.m_Function,
156*89c4ff92SAndroid Build Coastguard Worker location.FileLine()));
157*89c4ff92SAndroid Build Coastguard Worker }
158*89c4ff92SAndroid Build Coastguard Worker }
159*89c4ff92SAndroid Build Coastguard Worker
CheckConstTensorSize(const unsigned int constTensorSize,const unsigned int tensorSize,const CheckLocation & location)160*89c4ff92SAndroid Build Coastguard Worker void CheckConstTensorSize(const unsigned int constTensorSize,
161*89c4ff92SAndroid Build Coastguard Worker const unsigned int tensorSize,
162*89c4ff92SAndroid Build Coastguard Worker const CheckLocation& location)
163*89c4ff92SAndroid Build Coastguard Worker {
164*89c4ff92SAndroid Build Coastguard Worker if (constTensorSize != tensorSize)
165*89c4ff92SAndroid Build Coastguard Worker {
166*89c4ff92SAndroid Build Coastguard Worker throw ParseException(fmt::format("{0} wrong number of components supplied to tensor. at:{1}",
167*89c4ff92SAndroid Build Coastguard Worker location.m_Function,
168*89c4ff92SAndroid Build Coastguard Worker location.FileLine()));
169*89c4ff92SAndroid Build Coastguard Worker }
170*89c4ff92SAndroid Build Coastguard Worker }
171*89c4ff92SAndroid Build Coastguard Worker
172*89c4ff92SAndroid Build Coastguard Worker #define CHECK_TENSOR_PTR(TENSOR_PTR) \
173*89c4ff92SAndroid Build Coastguard Worker CheckTensorPtr(TENSOR_PTR, CHECK_LOCATION())
174*89c4ff92SAndroid Build Coastguard Worker
175*89c4ff92SAndroid Build Coastguard Worker #define CHECK_CONST_TENSOR_SIZE(CONST_TENSOR_SIZE, TENSOR_SIZE) \
176*89c4ff92SAndroid Build Coastguard Worker CheckConstTensorSize(CONST_TENSOR_SIZE, TENSOR_SIZE, CHECK_LOCATION())
177*89c4ff92SAndroid Build Coastguard Worker
178*89c4ff92SAndroid Build Coastguard Worker #define CHECK_CONST_TENSOR_PTR(TENSOR_PTR) \
179*89c4ff92SAndroid Build Coastguard Worker CheckConstTensorPtr(TENSOR_PTR, CHECK_LOCATION())
180*89c4ff92SAndroid Build Coastguard Worker
181*89c4ff92SAndroid Build Coastguard Worker #define CHECK_LAYERS(GRAPH, LAYERS_INDEX, LAYER_INDEX) \
182*89c4ff92SAndroid Build Coastguard Worker CheckLayers(GRAPH, LAYERS_INDEX, LAYER_INDEX, CHECK_LOCATION())
183*89c4ff92SAndroid Build Coastguard Worker
184*89c4ff92SAndroid Build Coastguard Worker #define CHECK_GRAPH(GRAPH, LAYERS_INDEX) \
185*89c4ff92SAndroid Build Coastguard Worker CheckGraph(GRAPH, LAYERS_INDEX, CHECK_LOCATION())
186*89c4ff92SAndroid Build Coastguard Worker }
187*89c4ff92SAndroid Build Coastguard Worker
CheckShape(const armnn::TensorShape & actual,const std::vector<uint32_t> & expected)188*89c4ff92SAndroid Build Coastguard Worker bool CheckShape(const armnn::TensorShape& actual, const std::vector<uint32_t>& expected)
189*89c4ff92SAndroid Build Coastguard Worker {
190*89c4ff92SAndroid Build Coastguard Worker const unsigned int actualSize = actual.GetNumDimensions();
191*89c4ff92SAndroid Build Coastguard Worker if (actualSize != expected.size())
192*89c4ff92SAndroid Build Coastguard Worker {
193*89c4ff92SAndroid Build Coastguard Worker return false;
194*89c4ff92SAndroid Build Coastguard Worker }
195*89c4ff92SAndroid Build Coastguard Worker
196*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0u; i < actualSize; i++)
197*89c4ff92SAndroid Build Coastguard Worker {
198*89c4ff92SAndroid Build Coastguard Worker if (actual[i] != static_cast<unsigned int>(expected[i]))
199*89c4ff92SAndroid Build Coastguard Worker {
200*89c4ff92SAndroid Build Coastguard Worker return false;
201*89c4ff92SAndroid Build Coastguard Worker }
202*89c4ff92SAndroid Build Coastguard Worker }
203*89c4ff92SAndroid Build Coastguard Worker
204*89c4ff92SAndroid Build Coastguard Worker return true;
205*89c4ff92SAndroid Build Coastguard Worker }
206*89c4ff92SAndroid Build Coastguard Worker
DeserializerImpl()207*89c4ff92SAndroid Build Coastguard Worker IDeserializer::DeserializerImpl::DeserializerImpl()
208*89c4ff92SAndroid Build Coastguard Worker : m_Network(nullptr, nullptr),
209*89c4ff92SAndroid Build Coastguard Worker //May require LayerType_Max to be included
210*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions(Layer_MAX+1, &IDeserializer::DeserializerImpl::ParseUnsupportedLayer)
211*89c4ff92SAndroid Build Coastguard Worker {
212*89c4ff92SAndroid Build Coastguard Worker // register supported layers
213*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_AbsLayer] = &DeserializerImpl::ParseAbs;
214*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_ActivationLayer] = &DeserializerImpl::ParseActivation;
215*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_AdditionLayer] = &DeserializerImpl::ParseAdd;
216*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_ArgMinMaxLayer] = &DeserializerImpl::ParseArgMinMax;
217*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_BatchMatMulLayer] = &DeserializerImpl::ParseBatchMatMul;
218*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_BatchToSpaceNdLayer] = &DeserializerImpl::ParseBatchToSpaceNd;
219*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_BatchNormalizationLayer] = &DeserializerImpl::ParseBatchNormalization;
220*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_CastLayer] = &DeserializerImpl::ParseCast;
221*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_ChannelShuffleLayer] = &DeserializerImpl::ParseChannelShuffle;
222*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_ComparisonLayer] = &DeserializerImpl::ParseComparison;
223*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_ConcatLayer] = &DeserializerImpl::ParseConcat;
224*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_ConstantLayer] = &DeserializerImpl::ParseConstant;
225*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_Convolution2dLayer] = &DeserializerImpl::ParseConvolution2d;
226*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_Convolution3dLayer] = &DeserializerImpl::ParseConvolution3d;
227*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_DepthToSpaceLayer] = &DeserializerImpl::ParseDepthToSpace;
228*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_DepthwiseConvolution2dLayer] = &DeserializerImpl::ParseDepthwiseConvolution2d;
229*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_DequantizeLayer] = &DeserializerImpl::ParseDequantize;
230*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_DetectionPostProcessLayer] = &DeserializerImpl::ParseDetectionPostProcess;
231*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_DivisionLayer] = &DeserializerImpl::ParseDivision;
232*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_ElementwiseBinaryLayer] = &DeserializerImpl::ParseElementwiseBinary;
233*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_ElementwiseUnaryLayer] = &DeserializerImpl::ParseElementwiseUnary;
234*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_EqualLayer] = &DeserializerImpl::ParseEqual;
235*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_FullyConnectedLayer] = &DeserializerImpl::ParseFullyConnected;
236*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_FillLayer] = &DeserializerImpl::ParseFill;
237*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_FloorLayer] = &DeserializerImpl::ParseFloor;
238*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_GatherLayer] = &DeserializerImpl::ParseGather;
239*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_GatherNdLayer] = &DeserializerImpl::ParseGatherNd;
240*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_GreaterLayer] = &DeserializerImpl::ParseGreater;
241*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_InstanceNormalizationLayer] = &DeserializerImpl::ParseInstanceNormalization;
242*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_L2NormalizationLayer] = &DeserializerImpl::ParseL2Normalization;
243*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_LogicalBinaryLayer] = &DeserializerImpl::ParseLogicalBinary;
244*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_LogSoftmaxLayer] = &DeserializerImpl::ParseLogSoftmax;
245*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_LstmLayer] = &DeserializerImpl::ParseLstm;
246*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_MaximumLayer] = &DeserializerImpl::ParseMaximum;
247*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_MeanLayer] = &DeserializerImpl::ParseMean;
248*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_MinimumLayer] = &DeserializerImpl::ParseMinimum;
249*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_MergeLayer] = &DeserializerImpl::ParseMerge;
250*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_MergerLayer] = &DeserializerImpl::ParseConcat;
251*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_MultiplicationLayer] = &DeserializerImpl::ParseMultiplication;
252*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_NormalizationLayer] = &DeserializerImpl::ParseNormalization;
253*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_PadLayer] = &DeserializerImpl::ParsePad;
254*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_PermuteLayer] = &DeserializerImpl::ParsePermute;
255*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_Pooling2dLayer] = &DeserializerImpl::ParsePooling2d;
256*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_Pooling3dLayer] = &DeserializerImpl::ParsePooling3d;
257*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_PreluLayer] = &DeserializerImpl::ParsePrelu;
258*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_QLstmLayer] = &DeserializerImpl::ParseQLstm;
259*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_QuantizeLayer] = &DeserializerImpl::ParseQuantize;
260*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_QuantizedLstmLayer] = &DeserializerImpl::ParseQuantizedLstm;
261*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_RankLayer] = &DeserializerImpl::ParseRank;
262*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_ReduceLayer] = &DeserializerImpl::ParseReduce;
263*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_ReshapeLayer] = &DeserializerImpl::ParseReshape;
264*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_ResizeBilinearLayer] = &DeserializerImpl::ParseResizeBilinear;
265*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_ResizeLayer] = &DeserializerImpl::ParseResize;
266*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_RsqrtLayer] = &DeserializerImpl::ParseRsqrt;
267*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_ShapeLayer] = &DeserializerImpl::ParseShape;
268*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_SliceLayer] = &DeserializerImpl::ParseSlice;
269*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_SoftmaxLayer] = &DeserializerImpl::ParseSoftmax;
270*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_SpaceToBatchNdLayer] = &DeserializerImpl::ParseSpaceToBatchNd;
271*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_SpaceToDepthLayer] = &DeserializerImpl::ParseSpaceToDepth;
272*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_SplitterLayer] = &DeserializerImpl::ParseSplitter;
273*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_StackLayer] = &DeserializerImpl::ParseStack;
274*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_StandInLayer] = &DeserializerImpl::ParseStandIn;
275*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_StridedSliceLayer] = &DeserializerImpl::ParseStridedSlice;
276*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_SubtractionLayer] = &DeserializerImpl::ParseSubtraction;
277*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_SwitchLayer] = &DeserializerImpl::ParseSwitch;
278*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_TransposeConvolution2dLayer] = &DeserializerImpl::ParseTransposeConvolution2d;
279*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_TransposeLayer] = &DeserializerImpl::ParseTranspose;
280*89c4ff92SAndroid Build Coastguard Worker m_ParserFunctions[Layer_UnidirectionalSequenceLstmLayer] = &DeserializerImpl::ParseUnidirectionalSequenceLstm;
281*89c4ff92SAndroid Build Coastguard Worker }
282*89c4ff92SAndroid Build Coastguard Worker
GetBaseLayer(const GraphPtr & graphPtr,unsigned int layerIndex)283*89c4ff92SAndroid Build Coastguard Worker LayerBaseRawPtr IDeserializer::DeserializerImpl::GetBaseLayer(const GraphPtr& graphPtr, unsigned int layerIndex)
284*89c4ff92SAndroid Build Coastguard Worker {
285*89c4ff92SAndroid Build Coastguard Worker auto layerType = graphPtr->layers()->Get(layerIndex)->layer_type();
286*89c4ff92SAndroid Build Coastguard Worker
287*89c4ff92SAndroid Build Coastguard Worker switch(layerType)
288*89c4ff92SAndroid Build Coastguard Worker {
289*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_AbsLayer:
290*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_AbsLayer()->base();
291*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_ActivationLayer:
292*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_ActivationLayer()->base();
293*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_AdditionLayer:
294*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_AdditionLayer()->base();
295*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_ArgMinMaxLayer:
296*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_ArgMinMaxLayer()->base();
297*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_BatchMatMulLayer:
298*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_BatchMatMulLayer()->base();
299*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_BatchToSpaceNdLayer:
300*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_BatchToSpaceNdLayer()->base();
301*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_BatchNormalizationLayer:
302*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_BatchNormalizationLayer()->base();
303*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_CastLayer:
304*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_CastLayer()->base();
305*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_ChannelShuffleLayer:
306*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_ChannelShuffleLayer()->base();
307*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_ComparisonLayer:
308*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_ComparisonLayer()->base();
309*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_ConcatLayer:
310*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_ConcatLayer()->base();
311*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_ConstantLayer:
312*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_ConstantLayer()->base();
313*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_Convolution2dLayer:
314*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_Convolution2dLayer()->base();
315*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_Convolution3dLayer:
316*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_Convolution3dLayer()->base();
317*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_DepthToSpaceLayer:
318*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_DepthToSpaceLayer()->base();
319*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_DepthwiseConvolution2dLayer:
320*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_DepthwiseConvolution2dLayer()->base();
321*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_DequantizeLayer:
322*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_DequantizeLayer()->base();
323*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_DetectionPostProcessLayer:
324*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_DetectionPostProcessLayer()->base();
325*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_DivisionLayer:
326*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_DivisionLayer()->base();
327*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_EqualLayer:
328*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_EqualLayer()->base();
329*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_ElementwiseBinaryLayer:
330*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_ElementwiseBinaryLayer()->base();
331*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_ElementwiseUnaryLayer:
332*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_ElementwiseUnaryLayer()->base();
333*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_FullyConnectedLayer:
334*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_FullyConnectedLayer()->base();
335*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_FillLayer:
336*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_FillLayer()->base();
337*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_FloorLayer:
338*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_FloorLayer()->base();
339*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_GatherLayer:
340*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_GatherLayer()->base();
341*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_GatherNdLayer:
342*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_GatherNdLayer()->base();
343*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_GreaterLayer:
344*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_GreaterLayer()->base();
345*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_InputLayer:
346*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_InputLayer()->base()->base();
347*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_InstanceNormalizationLayer:
348*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_InstanceNormalizationLayer()->base();
349*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_L2NormalizationLayer:
350*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_L2NormalizationLayer()->base();
351*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_LogicalBinaryLayer:
352*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_LogicalBinaryLayer()->base();
353*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_LogSoftmaxLayer:
354*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_LogSoftmaxLayer()->base();
355*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_LstmLayer:
356*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_LstmLayer()->base();
357*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_MeanLayer:
358*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_MeanLayer()->base();
359*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_MinimumLayer:
360*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_MinimumLayer()->base();
361*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_MaximumLayer:
362*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_MaximumLayer()->base();
363*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_MergeLayer:
364*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_MergeLayer()->base();
365*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_MergerLayer:
366*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_MergerLayer()->base();
367*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_MultiplicationLayer:
368*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_MultiplicationLayer()->base();
369*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_NormalizationLayer:
370*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_NormalizationLayer()->base();
371*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_OutputLayer:
372*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_OutputLayer()->base()->base();
373*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_PadLayer:
374*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_PadLayer()->base();
375*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_PermuteLayer:
376*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_PermuteLayer()->base();
377*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_Pooling2dLayer:
378*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_Pooling2dLayer()->base();
379*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_Pooling3dLayer:
380*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_Pooling3dLayer()->base();
381*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_PreluLayer:
382*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_PreluLayer()->base();
383*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_QLstmLayer:
384*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_QLstmLayer()->base();
385*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_QuantizeLayer:
386*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_QuantizeLayer()->base();
387*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_QuantizedLstmLayer:
388*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_QuantizedLstmLayer()->base();
389*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_RankLayer:
390*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_RankLayer()->base();
391*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_ReduceLayer:
392*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_ReduceLayer()->base();
393*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_ReshapeLayer:
394*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_ReshapeLayer()->base();
395*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_ResizeBilinearLayer:
396*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_ResizeBilinearLayer()->base();
397*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_ResizeLayer:
398*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_ResizeLayer()->base();
399*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_RsqrtLayer:
400*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_RsqrtLayer()->base();
401*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_ShapeLayer:
402*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_ShapeLayer()->base();
403*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_SliceLayer:
404*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_SliceLayer()->base();
405*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_SoftmaxLayer:
406*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_SoftmaxLayer()->base();
407*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_SpaceToBatchNdLayer:
408*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_SpaceToBatchNdLayer()->base();
409*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_SpaceToDepthLayer:
410*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_SpaceToDepthLayer()->base();
411*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_SplitterLayer:
412*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_SplitterLayer()->base();
413*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_StackLayer:
414*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_StackLayer()->base();
415*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_StandInLayer:
416*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_StandInLayer()->base();
417*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_StridedSliceLayer:
418*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_StridedSliceLayer()->base();
419*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_SubtractionLayer:
420*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_SubtractionLayer()->base();
421*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_SwitchLayer:
422*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_SwitchLayer()->base();
423*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_TransposeConvolution2dLayer:
424*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_TransposeConvolution2dLayer()->base();
425*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_TransposeLayer:
426*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_TransposeLayer()->base();
427*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_UnidirectionalSequenceLstmLayer:
428*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_UnidirectionalSequenceLstmLayer()->base();
429*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_NONE:
430*89c4ff92SAndroid Build Coastguard Worker default:
431*89c4ff92SAndroid Build Coastguard Worker throw ParseException(fmt::format("Layer type {} not recognized", layerType));
432*89c4ff92SAndroid Build Coastguard Worker }
433*89c4ff92SAndroid Build Coastguard Worker }
434*89c4ff92SAndroid Build Coastguard Worker
GetLayerName(const GraphPtr & graph,unsigned int index)435*89c4ff92SAndroid Build Coastguard Worker std::string IDeserializer::DeserializerImpl::GetLayerName(const GraphPtr& graph, unsigned int index)
436*89c4ff92SAndroid Build Coastguard Worker {
437*89c4ff92SAndroid Build Coastguard Worker auto layer = GetBaseLayer(graph, index);
438*89c4ff92SAndroid Build Coastguard Worker assert(layer);
439*89c4ff92SAndroid Build Coastguard Worker return layer->layerName()->str();
440*89c4ff92SAndroid Build Coastguard Worker }
441*89c4ff92SAndroid Build Coastguard Worker
GetBindingLayerInfo(const GraphPtr & graphPtr,unsigned int layerIndex)442*89c4ff92SAndroid Build Coastguard Worker int32_t IDeserializer::DeserializerImpl::GetBindingLayerInfo(const GraphPtr& graphPtr, unsigned int layerIndex)
443*89c4ff92SAndroid Build Coastguard Worker {
444*89c4ff92SAndroid Build Coastguard Worker auto layerType = graphPtr->layers()->Get(layerIndex)->layer_type();
445*89c4ff92SAndroid Build Coastguard Worker
446*89c4ff92SAndroid Build Coastguard Worker if (layerType == Layer::Layer_InputLayer)
447*89c4ff92SAndroid Build Coastguard Worker {
448*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_InputLayer()->base()->layerBindingId();
449*89c4ff92SAndroid Build Coastguard Worker }
450*89c4ff92SAndroid Build Coastguard Worker else if ( layerType == Layer::Layer_OutputLayer )
451*89c4ff92SAndroid Build Coastguard Worker {
452*89c4ff92SAndroid Build Coastguard Worker return graphPtr->layers()->Get(layerIndex)->layer_as_OutputLayer()->base()->layerBindingId();
453*89c4ff92SAndroid Build Coastguard Worker }
454*89c4ff92SAndroid Build Coastguard Worker return 0;
455*89c4ff92SAndroid Build Coastguard Worker }
456*89c4ff92SAndroid Build Coastguard Worker
ToDataLayout(armnnSerializer::DataLayout dataLayout)457*89c4ff92SAndroid Build Coastguard Worker armnn::DataLayout ToDataLayout(armnnSerializer::DataLayout dataLayout)
458*89c4ff92SAndroid Build Coastguard Worker {
459*89c4ff92SAndroid Build Coastguard Worker switch (dataLayout)
460*89c4ff92SAndroid Build Coastguard Worker {
461*89c4ff92SAndroid Build Coastguard Worker case armnnSerializer::DataLayout::DataLayout_NHWC:
462*89c4ff92SAndroid Build Coastguard Worker return armnn::DataLayout::NHWC;
463*89c4ff92SAndroid Build Coastguard Worker case armnnSerializer::DataLayout::DataLayout_NDHWC:
464*89c4ff92SAndroid Build Coastguard Worker return armnn::DataLayout::NDHWC;
465*89c4ff92SAndroid Build Coastguard Worker case armnnSerializer::DataLayout::DataLayout_NCDHW:
466*89c4ff92SAndroid Build Coastguard Worker return armnn::DataLayout::NCDHW;
467*89c4ff92SAndroid Build Coastguard Worker case armnnSerializer::DataLayout::DataLayout_NCHW:
468*89c4ff92SAndroid Build Coastguard Worker default:
469*89c4ff92SAndroid Build Coastguard Worker return armnn::DataLayout::NCHW;
470*89c4ff92SAndroid Build Coastguard Worker }
471*89c4ff92SAndroid Build Coastguard Worker }
472*89c4ff92SAndroid Build Coastguard Worker
ToActivationFunction(armnnSerializer::ActivationFunction function)473*89c4ff92SAndroid Build Coastguard Worker armnn::ActivationFunction ToActivationFunction(armnnSerializer::ActivationFunction function)
474*89c4ff92SAndroid Build Coastguard Worker {
475*89c4ff92SAndroid Build Coastguard Worker switch (function)
476*89c4ff92SAndroid Build Coastguard Worker {
477*89c4ff92SAndroid Build Coastguard Worker case armnnSerializer::ActivationFunction_Sigmoid:
478*89c4ff92SAndroid Build Coastguard Worker return armnn::ActivationFunction::Sigmoid;
479*89c4ff92SAndroid Build Coastguard Worker case armnnSerializer::ActivationFunction_TanH:
480*89c4ff92SAndroid Build Coastguard Worker return armnn::ActivationFunction::TanH;
481*89c4ff92SAndroid Build Coastguard Worker case armnnSerializer::ActivationFunction_Linear:
482*89c4ff92SAndroid Build Coastguard Worker return armnn::ActivationFunction::Linear;
483*89c4ff92SAndroid Build Coastguard Worker case armnnSerializer::ActivationFunction_ReLu:
484*89c4ff92SAndroid Build Coastguard Worker return armnn::ActivationFunction::ReLu;
485*89c4ff92SAndroid Build Coastguard Worker case armnnSerializer::ActivationFunction_BoundedReLu:
486*89c4ff92SAndroid Build Coastguard Worker return armnn::ActivationFunction::BoundedReLu;
487*89c4ff92SAndroid Build Coastguard Worker case armnnSerializer::ActivationFunction_LeakyReLu:
488*89c4ff92SAndroid Build Coastguard Worker return armnn::ActivationFunction::LeakyReLu;
489*89c4ff92SAndroid Build Coastguard Worker case armnnSerializer::ActivationFunction_Abs:
490*89c4ff92SAndroid Build Coastguard Worker return armnn::ActivationFunction::Abs;
491*89c4ff92SAndroid Build Coastguard Worker case armnnSerializer::ActivationFunction_Sqrt:
492*89c4ff92SAndroid Build Coastguard Worker return armnn::ActivationFunction::Sqrt;
493*89c4ff92SAndroid Build Coastguard Worker case armnnSerializer::ActivationFunction_Square:
494*89c4ff92SAndroid Build Coastguard Worker return armnn::ActivationFunction::Square;
495*89c4ff92SAndroid Build Coastguard Worker case armnnSerializer::ActivationFunction_Elu:
496*89c4ff92SAndroid Build Coastguard Worker return armnn::ActivationFunction::Elu;
497*89c4ff92SAndroid Build Coastguard Worker case armnnSerializer::ActivationFunction_HardSwish:
498*89c4ff92SAndroid Build Coastguard Worker return armnn::ActivationFunction::HardSwish;
499*89c4ff92SAndroid Build Coastguard Worker default:
500*89c4ff92SAndroid Build Coastguard Worker return armnn::ActivationFunction::Sigmoid;
501*89c4ff92SAndroid Build Coastguard Worker }
502*89c4ff92SAndroid Build Coastguard Worker }
503*89c4ff92SAndroid Build Coastguard Worker
ToArgMinMaxFunction(armnnSerializer::ArgMinMaxFunction function)504*89c4ff92SAndroid Build Coastguard Worker armnn::ArgMinMaxFunction ToArgMinMaxFunction(armnnSerializer::ArgMinMaxFunction function)
505*89c4ff92SAndroid Build Coastguard Worker {
506*89c4ff92SAndroid Build Coastguard Worker switch (function)
507*89c4ff92SAndroid Build Coastguard Worker {
508*89c4ff92SAndroid Build Coastguard Worker case armnnSerializer::ArgMinMaxFunction::ArgMinMaxFunction_Max:
509*89c4ff92SAndroid Build Coastguard Worker return armnn::ArgMinMaxFunction::Max;
510*89c4ff92SAndroid Build Coastguard Worker case armnnSerializer::ArgMinMaxFunction::ArgMinMaxFunction_Min:
511*89c4ff92SAndroid Build Coastguard Worker default:
512*89c4ff92SAndroid Build Coastguard Worker return armnn::ArgMinMaxFunction::Min;
513*89c4ff92SAndroid Build Coastguard Worker }
514*89c4ff92SAndroid Build Coastguard Worker }
515*89c4ff92SAndroid Build Coastguard Worker
ToComparisonOperation(armnnSerializer::ComparisonOperation operation)516*89c4ff92SAndroid Build Coastguard Worker armnn::ComparisonOperation ToComparisonOperation(armnnSerializer::ComparisonOperation operation)
517*89c4ff92SAndroid Build Coastguard Worker {
518*89c4ff92SAndroid Build Coastguard Worker switch (operation)
519*89c4ff92SAndroid Build Coastguard Worker {
520*89c4ff92SAndroid Build Coastguard Worker case armnnSerializer::ComparisonOperation::ComparisonOperation_Equal:
521*89c4ff92SAndroid Build Coastguard Worker return armnn::ComparisonOperation::Equal;
522*89c4ff92SAndroid Build Coastguard Worker case armnnSerializer::ComparisonOperation::ComparisonOperation_Greater:
523*89c4ff92SAndroid Build Coastguard Worker return armnn::ComparisonOperation::Greater;
524*89c4ff92SAndroid Build Coastguard Worker case armnnSerializer::ComparisonOperation::ComparisonOperation_GreaterOrEqual:
525*89c4ff92SAndroid Build Coastguard Worker return armnn::ComparisonOperation::GreaterOrEqual;
526*89c4ff92SAndroid Build Coastguard Worker case armnnSerializer::ComparisonOperation::ComparisonOperation_Less:
527*89c4ff92SAndroid Build Coastguard Worker return armnn::ComparisonOperation::Less;
528*89c4ff92SAndroid Build Coastguard Worker case armnnSerializer::ComparisonOperation::ComparisonOperation_LessOrEqual:
529*89c4ff92SAndroid Build Coastguard Worker return armnn::ComparisonOperation::LessOrEqual;
530*89c4ff92SAndroid Build Coastguard Worker case armnnSerializer::ComparisonOperation::ComparisonOperation_NotEqual:
531*89c4ff92SAndroid Build Coastguard Worker default:
532*89c4ff92SAndroid Build Coastguard Worker return armnn::ComparisonOperation::NotEqual;
533*89c4ff92SAndroid Build Coastguard Worker }
534*89c4ff92SAndroid Build Coastguard Worker }
535*89c4ff92SAndroid Build Coastguard Worker
ToReduceOperation(armnnSerializer::ReduceOperation operation)536*89c4ff92SAndroid Build Coastguard Worker armnn::ReduceOperation ToReduceOperation(armnnSerializer::ReduceOperation operation)
537*89c4ff92SAndroid Build Coastguard Worker {
538*89c4ff92SAndroid Build Coastguard Worker switch (operation)
539*89c4ff92SAndroid Build Coastguard Worker {
540*89c4ff92SAndroid Build Coastguard Worker case armnnSerializer::ReduceOperation::ReduceOperation_Sum:
541*89c4ff92SAndroid Build Coastguard Worker return armnn::ReduceOperation::Sum;
542*89c4ff92SAndroid Build Coastguard Worker case armnnSerializer::ReduceOperation::ReduceOperation_Max:
543*89c4ff92SAndroid Build Coastguard Worker return armnn::ReduceOperation::Max;
544*89c4ff92SAndroid Build Coastguard Worker case armnnSerializer::ReduceOperation::ReduceOperation_Mean:
545*89c4ff92SAndroid Build Coastguard Worker return armnn::ReduceOperation::Mean;
546*89c4ff92SAndroid Build Coastguard Worker case armnnSerializer::ReduceOperation::ReduceOperation_Min:
547*89c4ff92SAndroid Build Coastguard Worker return armnn::ReduceOperation::Min;
548*89c4ff92SAndroid Build Coastguard Worker case armnnSerializer::ReduceOperation::ReduceOperation_Prod:
549*89c4ff92SAndroid Build Coastguard Worker return armnn::ReduceOperation::Prod;
550*89c4ff92SAndroid Build Coastguard Worker default:
551*89c4ff92SAndroid Build Coastguard Worker return armnn::ReduceOperation::Sum;
552*89c4ff92SAndroid Build Coastguard Worker }
553*89c4ff92SAndroid Build Coastguard Worker }
554*89c4ff92SAndroid Build Coastguard Worker
ToLogicalBinaryOperation(armnnSerializer::LogicalBinaryOperation operation)555*89c4ff92SAndroid Build Coastguard Worker armnn::LogicalBinaryOperation ToLogicalBinaryOperation(armnnSerializer::LogicalBinaryOperation operation)
556*89c4ff92SAndroid Build Coastguard Worker {
557*89c4ff92SAndroid Build Coastguard Worker switch (operation)
558*89c4ff92SAndroid Build Coastguard Worker {
559*89c4ff92SAndroid Build Coastguard Worker case armnnSerializer::LogicalBinaryOperation::LogicalBinaryOperation_LogicalAnd:
560*89c4ff92SAndroid Build Coastguard Worker return armnn::LogicalBinaryOperation::LogicalAnd;
561*89c4ff92SAndroid Build Coastguard Worker case armnnSerializer::LogicalBinaryOperation::LogicalBinaryOperation_LogicalOr:
562*89c4ff92SAndroid Build Coastguard Worker return armnn::LogicalBinaryOperation::LogicalOr;
563*89c4ff92SAndroid Build Coastguard Worker default:
564*89c4ff92SAndroid Build Coastguard Worker throw armnn::InvalidArgumentException("Logical Binary operation unknown");
565*89c4ff92SAndroid Build Coastguard Worker }
566*89c4ff92SAndroid Build Coastguard Worker }
567*89c4ff92SAndroid Build Coastguard Worker
ToElementwiseBinaryOperation(armnnSerializer::BinaryOperation operation)568*89c4ff92SAndroid Build Coastguard Worker armnn::BinaryOperation ToElementwiseBinaryOperation(armnnSerializer::BinaryOperation operation)
569*89c4ff92SAndroid Build Coastguard Worker {
570*89c4ff92SAndroid Build Coastguard Worker switch (operation)
571*89c4ff92SAndroid Build Coastguard Worker {
572*89c4ff92SAndroid Build Coastguard Worker case armnnSerializer::BinaryOperation::BinaryOperation_Add:
573*89c4ff92SAndroid Build Coastguard Worker return armnn::BinaryOperation::Add;
574*89c4ff92SAndroid Build Coastguard Worker case armnnSerializer::BinaryOperation::BinaryOperation_Div:
575*89c4ff92SAndroid Build Coastguard Worker return armnn::BinaryOperation::Div;
576*89c4ff92SAndroid Build Coastguard Worker case armnnSerializer::BinaryOperation::BinaryOperation_Maximum:
577*89c4ff92SAndroid Build Coastguard Worker return armnn::BinaryOperation::Maximum;
578*89c4ff92SAndroid Build Coastguard Worker case armnnSerializer::BinaryOperation::BinaryOperation_Minimum:
579*89c4ff92SAndroid Build Coastguard Worker return armnn::BinaryOperation::Minimum;
580*89c4ff92SAndroid Build Coastguard Worker case armnnSerializer::BinaryOperation::BinaryOperation_Mul:
581*89c4ff92SAndroid Build Coastguard Worker return armnn::BinaryOperation::Mul;
582*89c4ff92SAndroid Build Coastguard Worker case armnnSerializer::BinaryOperation::BinaryOperation_Sub:
583*89c4ff92SAndroid Build Coastguard Worker return armnn::BinaryOperation::Sub;
584*89c4ff92SAndroid Build Coastguard Worker default:
585*89c4ff92SAndroid Build Coastguard Worker throw armnn::InvalidArgumentException("Binary operation unknown");
586*89c4ff92SAndroid Build Coastguard Worker }
587*89c4ff92SAndroid Build Coastguard Worker }
588*89c4ff92SAndroid Build Coastguard Worker
ToElementwiseUnaryOperation(armnnSerializer::UnaryOperation operation)589*89c4ff92SAndroid Build Coastguard Worker armnn::UnaryOperation ToElementwiseUnaryOperation(armnnSerializer::UnaryOperation operation)
590*89c4ff92SAndroid Build Coastguard Worker {
591*89c4ff92SAndroid Build Coastguard Worker switch (operation)
592*89c4ff92SAndroid Build Coastguard Worker {
593*89c4ff92SAndroid Build Coastguard Worker case armnnSerializer::UnaryOperation::UnaryOperation_Abs:
594*89c4ff92SAndroid Build Coastguard Worker return armnn::UnaryOperation::Abs;
595*89c4ff92SAndroid Build Coastguard Worker case armnnSerializer::UnaryOperation::UnaryOperation_Ceil:
596*89c4ff92SAndroid Build Coastguard Worker return armnn::UnaryOperation::Ceil;
597*89c4ff92SAndroid Build Coastguard Worker case armnnSerializer::UnaryOperation::UnaryOperation_Rsqrt:
598*89c4ff92SAndroid Build Coastguard Worker return armnn::UnaryOperation::Rsqrt;
599*89c4ff92SAndroid Build Coastguard Worker case armnnSerializer::UnaryOperation::UnaryOperation_Sqrt:
600*89c4ff92SAndroid Build Coastguard Worker return armnn::UnaryOperation::Sqrt;
601*89c4ff92SAndroid Build Coastguard Worker case armnnSerializer::UnaryOperation::UnaryOperation_Exp:
602*89c4ff92SAndroid Build Coastguard Worker return armnn::UnaryOperation::Exp;
603*89c4ff92SAndroid Build Coastguard Worker case armnnSerializer::UnaryOperation::UnaryOperation_Neg:
604*89c4ff92SAndroid Build Coastguard Worker return armnn::UnaryOperation::Neg;
605*89c4ff92SAndroid Build Coastguard Worker case armnnSerializer::UnaryOperation::UnaryOperation_LogicalNot:
606*89c4ff92SAndroid Build Coastguard Worker return armnn::UnaryOperation::LogicalNot;
607*89c4ff92SAndroid Build Coastguard Worker case armnnSerializer::UnaryOperation::UnaryOperation_Log:
608*89c4ff92SAndroid Build Coastguard Worker return armnn::UnaryOperation::Log;
609*89c4ff92SAndroid Build Coastguard Worker case armnnSerializer::UnaryOperation::UnaryOperation_Sin:
610*89c4ff92SAndroid Build Coastguard Worker return armnn::UnaryOperation::Sin;
611*89c4ff92SAndroid Build Coastguard Worker default:
612*89c4ff92SAndroid Build Coastguard Worker throw armnn::InvalidArgumentException("Unary operation unknown");
613*89c4ff92SAndroid Build Coastguard Worker }
614*89c4ff92SAndroid Build Coastguard Worker }
615*89c4ff92SAndroid Build Coastguard Worker
ToPaddingMode(armnnSerializer::PaddingMode paddingMode)616*89c4ff92SAndroid Build Coastguard Worker armnn::PaddingMode ToPaddingMode(armnnSerializer::PaddingMode paddingMode)
617*89c4ff92SAndroid Build Coastguard Worker {
618*89c4ff92SAndroid Build Coastguard Worker switch (paddingMode)
619*89c4ff92SAndroid Build Coastguard Worker {
620*89c4ff92SAndroid Build Coastguard Worker case armnnSerializer::PaddingMode::PaddingMode_Reflect:
621*89c4ff92SAndroid Build Coastguard Worker return armnn::PaddingMode::Reflect;
622*89c4ff92SAndroid Build Coastguard Worker case armnnSerializer::PaddingMode::PaddingMode_Symmetric:
623*89c4ff92SAndroid Build Coastguard Worker return armnn::PaddingMode::Symmetric;
624*89c4ff92SAndroid Build Coastguard Worker default:
625*89c4ff92SAndroid Build Coastguard Worker return armnn::PaddingMode::Constant;
626*89c4ff92SAndroid Build Coastguard Worker }
627*89c4ff92SAndroid Build Coastguard Worker }
628*89c4ff92SAndroid Build Coastguard Worker
ToResizeMethod(armnnSerializer::ResizeMethod method)629*89c4ff92SAndroid Build Coastguard Worker armnn::ResizeMethod ToResizeMethod(armnnSerializer::ResizeMethod method)
630*89c4ff92SAndroid Build Coastguard Worker {
631*89c4ff92SAndroid Build Coastguard Worker switch (method)
632*89c4ff92SAndroid Build Coastguard Worker {
633*89c4ff92SAndroid Build Coastguard Worker case armnnSerializer::ResizeMethod_NearestNeighbor:
634*89c4ff92SAndroid Build Coastguard Worker return armnn::ResizeMethod::NearestNeighbor;
635*89c4ff92SAndroid Build Coastguard Worker case armnnSerializer::ResizeMethod_Bilinear:
636*89c4ff92SAndroid Build Coastguard Worker return armnn::ResizeMethod::Bilinear;
637*89c4ff92SAndroid Build Coastguard Worker default:
638*89c4ff92SAndroid Build Coastguard Worker return armnn::ResizeMethod::NearestNeighbor;
639*89c4ff92SAndroid Build Coastguard Worker }
640*89c4ff92SAndroid Build Coastguard Worker }
641*89c4ff92SAndroid Build Coastguard Worker
ToTensorInfo(TensorRawPtr tensorPtr)642*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo ToTensorInfo(TensorRawPtr tensorPtr)
643*89c4ff92SAndroid Build Coastguard Worker {
644*89c4ff92SAndroid Build Coastguard Worker armnn::DataType type;
645*89c4ff92SAndroid Build Coastguard Worker CHECK_TENSOR_PTR(tensorPtr);
646*89c4ff92SAndroid Build Coastguard Worker
647*89c4ff92SAndroid Build Coastguard Worker switch (tensorPtr->dataType())
648*89c4ff92SAndroid Build Coastguard Worker {
649*89c4ff92SAndroid Build Coastguard Worker case DataType_QAsymmS8:
650*89c4ff92SAndroid Build Coastguard Worker type = armnn::DataType::QAsymmS8;
651*89c4ff92SAndroid Build Coastguard Worker break;
652*89c4ff92SAndroid Build Coastguard Worker case DataType_QSymmS8:
653*89c4ff92SAndroid Build Coastguard Worker type = armnn::DataType::QSymmS8;
654*89c4ff92SAndroid Build Coastguard Worker break;
655*89c4ff92SAndroid Build Coastguard Worker case DataType_QuantisedAsymm8:
656*89c4ff92SAndroid Build Coastguard Worker case DataType_QAsymmU8:
657*89c4ff92SAndroid Build Coastguard Worker type = armnn::DataType::QAsymmU8;
658*89c4ff92SAndroid Build Coastguard Worker break;
659*89c4ff92SAndroid Build Coastguard Worker case DataType_QSymmS16:
660*89c4ff92SAndroid Build Coastguard Worker case DataType_QuantisedSymm16:
661*89c4ff92SAndroid Build Coastguard Worker type = armnn::DataType::QSymmS16;
662*89c4ff92SAndroid Build Coastguard Worker break;
663*89c4ff92SAndroid Build Coastguard Worker case DataType_Signed32:
664*89c4ff92SAndroid Build Coastguard Worker type = armnn::DataType::Signed32;
665*89c4ff92SAndroid Build Coastguard Worker break;
666*89c4ff92SAndroid Build Coastguard Worker case DataType_Signed64:
667*89c4ff92SAndroid Build Coastguard Worker type = armnn::DataType::Signed64;
668*89c4ff92SAndroid Build Coastguard Worker break;
669*89c4ff92SAndroid Build Coastguard Worker case DataType_Float32:
670*89c4ff92SAndroid Build Coastguard Worker type = armnn::DataType::Float32;
671*89c4ff92SAndroid Build Coastguard Worker break;
672*89c4ff92SAndroid Build Coastguard Worker case DataType_Float16:
673*89c4ff92SAndroid Build Coastguard Worker type = armnn::DataType::Float16;
674*89c4ff92SAndroid Build Coastguard Worker break;
675*89c4ff92SAndroid Build Coastguard Worker case DataType_Boolean:
676*89c4ff92SAndroid Build Coastguard Worker type = armnn::DataType::Boolean;
677*89c4ff92SAndroid Build Coastguard Worker break;
678*89c4ff92SAndroid Build Coastguard Worker default:
679*89c4ff92SAndroid Build Coastguard Worker {
680*89c4ff92SAndroid Build Coastguard Worker CheckLocation location = CHECK_LOCATION();
681*89c4ff92SAndroid Build Coastguard Worker throw ParseException(fmt::format("Unsupported data type {0} = {1}. {2}",
682*89c4ff92SAndroid Build Coastguard Worker tensorPtr->dataType(),
683*89c4ff92SAndroid Build Coastguard Worker EnumNameDataType(tensorPtr->dataType()),
684*89c4ff92SAndroid Build Coastguard Worker location.AsString()));
685*89c4ff92SAndroid Build Coastguard Worker }
686*89c4ff92SAndroid Build Coastguard Worker }
687*89c4ff92SAndroid Build Coastguard Worker
688*89c4ff92SAndroid Build Coastguard Worker float quantizationScale = tensorPtr->quantizationScale();
689*89c4ff92SAndroid Build Coastguard Worker int32_t quantizationOffset = tensorPtr->quantizationOffset();
690*89c4ff92SAndroid Build Coastguard Worker
691*89c4ff92SAndroid Build Coastguard Worker if (tensorPtr->dimensionality() == static_cast<unsigned int>(Dimensionality::Scalar))
692*89c4ff92SAndroid Build Coastguard Worker {
693*89c4ff92SAndroid Build Coastguard Worker return armnn::TensorInfo(TensorShape{armnn::Dimensionality::Scalar},
694*89c4ff92SAndroid Build Coastguard Worker type,
695*89c4ff92SAndroid Build Coastguard Worker quantizationScale,
696*89c4ff92SAndroid Build Coastguard Worker quantizationOffset);
697*89c4ff92SAndroid Build Coastguard Worker }
698*89c4ff92SAndroid Build Coastguard Worker else if (tensorPtr->dimensionality() == static_cast<unsigned int>(Dimensionality::NotSpecified))
699*89c4ff92SAndroid Build Coastguard Worker {
700*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo result(TensorShape{Dimensionality::NotSpecified},
701*89c4ff92SAndroid Build Coastguard Worker type,
702*89c4ff92SAndroid Build Coastguard Worker quantizationScale,
703*89c4ff92SAndroid Build Coastguard Worker quantizationOffset);
704*89c4ff92SAndroid Build Coastguard Worker return result;
705*89c4ff92SAndroid Build Coastguard Worker }
706*89c4ff92SAndroid Build Coastguard Worker
707*89c4ff92SAndroid Build Coastguard Worker auto dimensions = tensorPtr->dimensions();
708*89c4ff92SAndroid Build Coastguard Worker unsigned int size = dimensions->size();
709*89c4ff92SAndroid Build Coastguard Worker std::vector<unsigned int> outputDims(dimensions->begin(), dimensions->begin() + size);
710*89c4ff92SAndroid Build Coastguard Worker bool dimensionsSpecificity[armnn::MaxNumOfTensorDimensions];
711*89c4ff92SAndroid Build Coastguard Worker std::fill_n(dimensionsSpecificity, armnn::MaxNumOfTensorDimensions, true);
712*89c4ff92SAndroid Build Coastguard Worker // For backwards compatibility check if the dimensionSpecificity vector is present first.
713*89c4ff92SAndroid Build Coastguard Worker // The default is to have dimensionSpecificity set to all true's anyway.
714*89c4ff92SAndroid Build Coastguard Worker if (tensorPtr->dimensionSpecificity() != nullptr)
715*89c4ff92SAndroid Build Coastguard Worker {
716*89c4ff92SAndroid Build Coastguard Worker auto dimensionSpecificity = tensorPtr->dimensionSpecificity();
717*89c4ff92SAndroid Build Coastguard Worker size = dimensionSpecificity->size();
718*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < size; ++i)
719*89c4ff92SAndroid Build Coastguard Worker {
720*89c4ff92SAndroid Build Coastguard Worker dimensionsSpecificity[i] = dimensionSpecificity->Get(i);
721*89c4ff92SAndroid Build Coastguard Worker }
722*89c4ff92SAndroid Build Coastguard Worker }
723*89c4ff92SAndroid Build Coastguard Worker // Construct a TensorShape
724*89c4ff92SAndroid Build Coastguard Worker TensorShape shape(size, outputDims.data(), dimensionsSpecificity);
725*89c4ff92SAndroid Build Coastguard Worker
726*89c4ff92SAndroid Build Coastguard Worker auto quantizationScales = tensorPtr->quantizationScales();
727*89c4ff92SAndroid Build Coastguard Worker if (quantizationScales)
728*89c4ff92SAndroid Build Coastguard Worker {
729*89c4ff92SAndroid Build Coastguard Worker unsigned int quantizationScalesSize = quantizationScales->size();
730*89c4ff92SAndroid Build Coastguard Worker std::vector<float> scales(quantizationScales->begin(), quantizationScales->begin() + quantizationScalesSize);
731*89c4ff92SAndroid Build Coastguard Worker unsigned int quantizationDim = tensorPtr->quantizationDim();
732*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo result(shape,
733*89c4ff92SAndroid Build Coastguard Worker type,
734*89c4ff92SAndroid Build Coastguard Worker scales,
735*89c4ff92SAndroid Build Coastguard Worker quantizationDim);
736*89c4ff92SAndroid Build Coastguard Worker return result;
737*89c4ff92SAndroid Build Coastguard Worker }
738*89c4ff92SAndroid Build Coastguard Worker
739*89c4ff92SAndroid Build Coastguard Worker // two statements (on purpose) for easier debugging:
740*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo result(shape,
741*89c4ff92SAndroid Build Coastguard Worker type,
742*89c4ff92SAndroid Build Coastguard Worker quantizationScale,
743*89c4ff92SAndroid Build Coastguard Worker quantizationOffset);
744*89c4ff92SAndroid Build Coastguard Worker
745*89c4ff92SAndroid Build Coastguard Worker return result;
746*89c4ff92SAndroid Build Coastguard Worker }
747*89c4ff92SAndroid Build Coastguard Worker
ToConstTensor(ConstTensorRawPtr constTensorPtr)748*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor ToConstTensor(ConstTensorRawPtr constTensorPtr)
749*89c4ff92SAndroid Build Coastguard Worker {
750*89c4ff92SAndroid Build Coastguard Worker CHECK_CONST_TENSOR_PTR(constTensorPtr);
751*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo tensorInfo = ToTensorInfo(constTensorPtr->info());
752*89c4ff92SAndroid Build Coastguard Worker tensorInfo.SetConstant();
753*89c4ff92SAndroid Build Coastguard Worker
754*89c4ff92SAndroid Build Coastguard Worker switch (constTensorPtr->data_type())
755*89c4ff92SAndroid Build Coastguard Worker {
756*89c4ff92SAndroid Build Coastguard Worker case ConstTensorData_ByteData:
757*89c4ff92SAndroid Build Coastguard Worker {
758*89c4ff92SAndroid Build Coastguard Worker auto byteData = constTensorPtr->data_as_ByteData()->data();
759*89c4ff92SAndroid Build Coastguard Worker CHECK_CONST_TENSOR_SIZE(byteData->size(), tensorInfo.GetNumElements());
760*89c4ff92SAndroid Build Coastguard Worker return armnn::ConstTensor(tensorInfo, byteData->data());
761*89c4ff92SAndroid Build Coastguard Worker }
762*89c4ff92SAndroid Build Coastguard Worker case ConstTensorData_ShortData:
763*89c4ff92SAndroid Build Coastguard Worker {
764*89c4ff92SAndroid Build Coastguard Worker auto shortData = constTensorPtr->data_as_ShortData()->data();
765*89c4ff92SAndroid Build Coastguard Worker CHECK_CONST_TENSOR_SIZE(shortData->size(), tensorInfo.GetNumElements());
766*89c4ff92SAndroid Build Coastguard Worker return armnn::ConstTensor(tensorInfo, shortData->data());
767*89c4ff92SAndroid Build Coastguard Worker }
768*89c4ff92SAndroid Build Coastguard Worker case ConstTensorData_IntData:
769*89c4ff92SAndroid Build Coastguard Worker {
770*89c4ff92SAndroid Build Coastguard Worker auto intData = constTensorPtr->data_as_IntData()->data();
771*89c4ff92SAndroid Build Coastguard Worker CHECK_CONST_TENSOR_SIZE(intData->size(), tensorInfo.GetNumElements());
772*89c4ff92SAndroid Build Coastguard Worker return armnn::ConstTensor(tensorInfo, intData->data());
773*89c4ff92SAndroid Build Coastguard Worker }
774*89c4ff92SAndroid Build Coastguard Worker case ConstTensorData_LongData:
775*89c4ff92SAndroid Build Coastguard Worker {
776*89c4ff92SAndroid Build Coastguard Worker auto longData = constTensorPtr->data_as_LongData()->data();
777*89c4ff92SAndroid Build Coastguard Worker CHECK_CONST_TENSOR_SIZE(longData->size(), tensorInfo.GetNumElements());
778*89c4ff92SAndroid Build Coastguard Worker return armnn::ConstTensor(tensorInfo, longData->data());
779*89c4ff92SAndroid Build Coastguard Worker }
780*89c4ff92SAndroid Build Coastguard Worker default:
781*89c4ff92SAndroid Build Coastguard Worker {
782*89c4ff92SAndroid Build Coastguard Worker CheckLocation location = CHECK_LOCATION();
783*89c4ff92SAndroid Build Coastguard Worker throw ParseException(fmt::format("Unsupported data type {0} = {1}. {2}",
784*89c4ff92SAndroid Build Coastguard Worker constTensorPtr->data_type(),
785*89c4ff92SAndroid Build Coastguard Worker EnumNameConstTensorData(constTensorPtr->data_type()),
786*89c4ff92SAndroid Build Coastguard Worker location.AsString()));
787*89c4ff92SAndroid Build Coastguard Worker }
788*89c4ff92SAndroid Build Coastguard Worker }
789*89c4ff92SAndroid Build Coastguard Worker }
790*89c4ff92SAndroid Build Coastguard Worker
GetInputs(const GraphPtr & graphPtr,unsigned int layerIndex)791*89c4ff92SAndroid Build Coastguard Worker TensorRawPtrVector IDeserializer::DeserializerImpl::GetInputs(const GraphPtr& graphPtr, unsigned int layerIndex)
792*89c4ff92SAndroid Build Coastguard Worker {
793*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graphPtr, 0, layerIndex);
794*89c4ff92SAndroid Build Coastguard Worker auto layer = GetBaseLayer(graphPtr, layerIndex);
795*89c4ff92SAndroid Build Coastguard Worker const auto& numInputs = layer->inputSlots()->size();
796*89c4ff92SAndroid Build Coastguard Worker
797*89c4ff92SAndroid Build Coastguard Worker TensorRawPtrVector result(numInputs);
798*89c4ff92SAndroid Build Coastguard Worker
799*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i=0; i<numInputs; ++i)
800*89c4ff92SAndroid Build Coastguard Worker {
801*89c4ff92SAndroid Build Coastguard Worker auto inputId = CHECKED_NON_NEGATIVE(static_cast<int32_t>
802*89c4ff92SAndroid Build Coastguard Worker (layer->inputSlots()->Get(i)->connection()->sourceLayerIndex()));
803*89c4ff92SAndroid Build Coastguard Worker result[i] = GetBaseLayer(graphPtr, inputId)->outputSlots()->Get(0)->tensorInfo();
804*89c4ff92SAndroid Build Coastguard Worker }
805*89c4ff92SAndroid Build Coastguard Worker return result;
806*89c4ff92SAndroid Build Coastguard Worker }
807*89c4ff92SAndroid Build Coastguard Worker
GetOutputs(const GraphPtr & graphPtr,unsigned int layerIndex)808*89c4ff92SAndroid Build Coastguard Worker TensorRawPtrVector IDeserializer::DeserializerImpl::GetOutputs(const GraphPtr& graphPtr, unsigned int layerIndex)
809*89c4ff92SAndroid Build Coastguard Worker {
810*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graphPtr, 0, layerIndex);
811*89c4ff92SAndroid Build Coastguard Worker auto layer = GetBaseLayer(graphPtr, layerIndex);
812*89c4ff92SAndroid Build Coastguard Worker const auto& numOutputs = layer->outputSlots()->size();
813*89c4ff92SAndroid Build Coastguard Worker
814*89c4ff92SAndroid Build Coastguard Worker TensorRawPtrVector result(numOutputs);
815*89c4ff92SAndroid Build Coastguard Worker
816*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i=0; i<numOutputs; ++i)
817*89c4ff92SAndroid Build Coastguard Worker {
818*89c4ff92SAndroid Build Coastguard Worker result[i] = layer->outputSlots()->Get(i)->tensorInfo();
819*89c4ff92SAndroid Build Coastguard Worker }
820*89c4ff92SAndroid Build Coastguard Worker return result;
821*89c4ff92SAndroid Build Coastguard Worker }
822*89c4ff92SAndroid Build Coastguard Worker
ParseUnsupportedLayer(GraphPtr graph,unsigned int layerIndex)823*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseUnsupportedLayer(GraphPtr graph, unsigned int layerIndex)
824*89c4ff92SAndroid Build Coastguard Worker {
825*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
826*89c4ff92SAndroid Build Coastguard Worker const auto layerName = GetBaseLayer(graph, layerIndex)->layerName()->c_str();
827*89c4ff92SAndroid Build Coastguard Worker throw ParseException(fmt::format("Layer not supported. layerIndex: {0} "
828*89c4ff92SAndroid Build Coastguard Worker "layerName: {1} / {2}",
829*89c4ff92SAndroid Build Coastguard Worker layerIndex,
830*89c4ff92SAndroid Build Coastguard Worker layerName,
831*89c4ff92SAndroid Build Coastguard Worker CHECK_LOCATION().AsString()));
832*89c4ff92SAndroid Build Coastguard Worker }
833*89c4ff92SAndroid Build Coastguard Worker
ResetParser()834*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ResetParser()
835*89c4ff92SAndroid Build Coastguard Worker {
836*89c4ff92SAndroid Build Coastguard Worker m_Network = armnn::INetworkPtr(nullptr, nullptr);
837*89c4ff92SAndroid Build Coastguard Worker m_InputBindings.clear();
838*89c4ff92SAndroid Build Coastguard Worker m_OutputBindings.clear();
839*89c4ff92SAndroid Build Coastguard Worker }
840*89c4ff92SAndroid Build Coastguard Worker
841*89c4ff92SAndroid Build Coastguard Worker
CreateNetworkFromBinary(const std::vector<uint8_t> & binaryContent)842*89c4ff92SAndroid Build Coastguard Worker INetworkPtr IDeserializer::DeserializerImpl::CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent)
843*89c4ff92SAndroid Build Coastguard Worker {
844*89c4ff92SAndroid Build Coastguard Worker ResetParser();
845*89c4ff92SAndroid Build Coastguard Worker GraphPtr graph = LoadGraphFromBinary(binaryContent.data(), binaryContent.size());
846*89c4ff92SAndroid Build Coastguard Worker return CreateNetworkFromGraph(graph);
847*89c4ff92SAndroid Build Coastguard Worker }
848*89c4ff92SAndroid Build Coastguard Worker
CreateNetworkFromBinary(std::istream & binaryContent)849*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr IDeserializer::DeserializerImpl::CreateNetworkFromBinary(std::istream& binaryContent)
850*89c4ff92SAndroid Build Coastguard Worker {
851*89c4ff92SAndroid Build Coastguard Worker ResetParser();
852*89c4ff92SAndroid Build Coastguard Worker if (binaryContent.fail()) {
853*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(error) << (std::string("Cannot read input"));
854*89c4ff92SAndroid Build Coastguard Worker throw ParseException("Unable to read Input stream data");
855*89c4ff92SAndroid Build Coastguard Worker }
856*89c4ff92SAndroid Build Coastguard Worker binaryContent.seekg(0, std::ios::end);
857*89c4ff92SAndroid Build Coastguard Worker const std::streamoff size = binaryContent.tellg();
858*89c4ff92SAndroid Build Coastguard Worker std::vector<char> content(static_cast<size_t>(size));
859*89c4ff92SAndroid Build Coastguard Worker binaryContent.seekg(0);
860*89c4ff92SAndroid Build Coastguard Worker binaryContent.read(content.data(), static_cast<std::streamsize>(size));
861*89c4ff92SAndroid Build Coastguard Worker GraphPtr graph = LoadGraphFromBinary(reinterpret_cast<uint8_t*>(content.data()), static_cast<size_t>(size));
862*89c4ff92SAndroid Build Coastguard Worker return CreateNetworkFromGraph(graph);
863*89c4ff92SAndroid Build Coastguard Worker }
864*89c4ff92SAndroid Build Coastguard Worker
LoadGraphFromBinary(const uint8_t * binaryContent,size_t len)865*89c4ff92SAndroid Build Coastguard Worker GraphPtr IDeserializer::DeserializerImpl::LoadGraphFromBinary(const uint8_t* binaryContent, size_t len)
866*89c4ff92SAndroid Build Coastguard Worker {
867*89c4ff92SAndroid Build Coastguard Worker if (binaryContent == nullptr)
868*89c4ff92SAndroid Build Coastguard Worker {
869*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(fmt::format("Invalid (null) binary content {}",
870*89c4ff92SAndroid Build Coastguard Worker CHECK_LOCATION().AsString()));
871*89c4ff92SAndroid Build Coastguard Worker }
872*89c4ff92SAndroid Build Coastguard Worker flatbuffers::Verifier verifier(binaryContent, len);
873*89c4ff92SAndroid Build Coastguard Worker if (verifier.VerifyBuffer<SerializedGraph>() == false)
874*89c4ff92SAndroid Build Coastguard Worker {
875*89c4ff92SAndroid Build Coastguard Worker throw ParseException(fmt::format("Buffer doesn't conform to the expected Armnn "
876*89c4ff92SAndroid Build Coastguard Worker "flatbuffers format. size:{0} {1}",
877*89c4ff92SAndroid Build Coastguard Worker len,
878*89c4ff92SAndroid Build Coastguard Worker CHECK_LOCATION().AsString()));
879*89c4ff92SAndroid Build Coastguard Worker }
880*89c4ff92SAndroid Build Coastguard Worker return GetSerializedGraph(binaryContent);
881*89c4ff92SAndroid Build Coastguard Worker }
882*89c4ff92SAndroid Build Coastguard Worker
CreateNetworkFromGraph(GraphPtr graph)883*89c4ff92SAndroid Build Coastguard Worker INetworkPtr IDeserializer::DeserializerImpl::CreateNetworkFromGraph(GraphPtr graph)
884*89c4ff92SAndroid Build Coastguard Worker {
885*89c4ff92SAndroid Build Coastguard Worker m_Network = INetwork::Create();
886*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(graph != nullptr);
887*89c4ff92SAndroid Build Coastguard Worker unsigned int layerIndex = 0;
888*89c4ff92SAndroid Build Coastguard Worker for (AnyLayer const* layer : *graph->layers())
889*89c4ff92SAndroid Build Coastguard Worker {
890*89c4ff92SAndroid Build Coastguard Worker if (layer->layer_type() != Layer_InputLayer &&
891*89c4ff92SAndroid Build Coastguard Worker layer->layer_type() != Layer_OutputLayer)
892*89c4ff92SAndroid Build Coastguard Worker {
893*89c4ff92SAndroid Build Coastguard Worker // lookup and call the parser function
894*89c4ff92SAndroid Build Coastguard Worker auto& parserFunction = m_ParserFunctions[layer->layer_type()];
895*89c4ff92SAndroid Build Coastguard Worker (this->*parserFunction)(graph, layerIndex);
896*89c4ff92SAndroid Build Coastguard Worker }
897*89c4ff92SAndroid Build Coastguard Worker ++layerIndex;
898*89c4ff92SAndroid Build Coastguard Worker }
899*89c4ff92SAndroid Build Coastguard Worker
900*89c4ff92SAndroid Build Coastguard Worker SetupInputLayers(graph);
901*89c4ff92SAndroid Build Coastguard Worker SetupOutputLayers(graph);
902*89c4ff92SAndroid Build Coastguard Worker
903*89c4ff92SAndroid Build Coastguard Worker // establish the connections from the layer outputs to the inputs of the subsequent layers
904*89c4ff92SAndroid Build Coastguard Worker for (auto&& graphIt : m_GraphConnections)
905*89c4ff92SAndroid Build Coastguard Worker {
906*89c4ff92SAndroid Build Coastguard Worker Connections& connections = graphIt.second;
907*89c4ff92SAndroid Build Coastguard Worker for (auto&& outputIt : connections.outputSlots)
908*89c4ff92SAndroid Build Coastguard Worker {
909*89c4ff92SAndroid Build Coastguard Worker const unsigned int outputSlotIndex = outputIt.first;
910*89c4ff92SAndroid Build Coastguard Worker IOutputSlot* outputSlot = outputIt.second;
911*89c4ff92SAndroid Build Coastguard Worker if (connections.inputSlots.find(outputSlotIndex) != connections.inputSlots.end())
912*89c4ff92SAndroid Build Coastguard Worker {
913*89c4ff92SAndroid Build Coastguard Worker for (IInputSlot* inputSlot : connections.inputSlots[outputSlotIndex])
914*89c4ff92SAndroid Build Coastguard Worker {
915*89c4ff92SAndroid Build Coastguard Worker outputSlot->Connect(*inputSlot);
916*89c4ff92SAndroid Build Coastguard Worker }
917*89c4ff92SAndroid Build Coastguard Worker }
918*89c4ff92SAndroid Build Coastguard Worker }
919*89c4ff92SAndroid Build Coastguard Worker }
920*89c4ff92SAndroid Build Coastguard Worker
921*89c4ff92SAndroid Build Coastguard Worker return std::move(m_Network);
922*89c4ff92SAndroid Build Coastguard Worker }
923*89c4ff92SAndroid Build Coastguard Worker
GetNetworkInputBindingInfo(unsigned int layerIndex,const std::string & name) const924*89c4ff92SAndroid Build Coastguard Worker BindingPointInfo IDeserializer::DeserializerImpl::GetNetworkInputBindingInfo(unsigned int layerIndex,
925*89c4ff92SAndroid Build Coastguard Worker const std::string& name) const
926*89c4ff92SAndroid Build Coastguard Worker {
927*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(layerIndex);
928*89c4ff92SAndroid Build Coastguard Worker for (auto inputBinding : m_InputBindings)
929*89c4ff92SAndroid Build Coastguard Worker {
930*89c4ff92SAndroid Build Coastguard Worker if (inputBinding.first == name)
931*89c4ff92SAndroid Build Coastguard Worker {
932*89c4ff92SAndroid Build Coastguard Worker return inputBinding.second;
933*89c4ff92SAndroid Build Coastguard Worker }
934*89c4ff92SAndroid Build Coastguard Worker }
935*89c4ff92SAndroid Build Coastguard Worker throw ParseException(fmt::format("No input binding found for layer:{0} / {1}",
936*89c4ff92SAndroid Build Coastguard Worker name,
937*89c4ff92SAndroid Build Coastguard Worker CHECK_LOCATION().AsString()));
938*89c4ff92SAndroid Build Coastguard Worker }
939*89c4ff92SAndroid Build Coastguard Worker
GetNetworkOutputBindingInfo(unsigned int layerIndex,const std::string & name) const940*89c4ff92SAndroid Build Coastguard Worker BindingPointInfo IDeserializer::DeserializerImpl::GetNetworkOutputBindingInfo(unsigned int layerIndex,
941*89c4ff92SAndroid Build Coastguard Worker const std::string& name) const
942*89c4ff92SAndroid Build Coastguard Worker {
943*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(layerIndex);
944*89c4ff92SAndroid Build Coastguard Worker for (auto outputBinding : m_OutputBindings)
945*89c4ff92SAndroid Build Coastguard Worker {
946*89c4ff92SAndroid Build Coastguard Worker if (outputBinding.first == name)
947*89c4ff92SAndroid Build Coastguard Worker {
948*89c4ff92SAndroid Build Coastguard Worker return outputBinding.second;
949*89c4ff92SAndroid Build Coastguard Worker }
950*89c4ff92SAndroid Build Coastguard Worker }
951*89c4ff92SAndroid Build Coastguard Worker throw ParseException(fmt::format("No output binding found for layer:{0} / {1}",
952*89c4ff92SAndroid Build Coastguard Worker name,
953*89c4ff92SAndroid Build Coastguard Worker CHECK_LOCATION().AsString()));
954*89c4ff92SAndroid Build Coastguard Worker }
955*89c4ff92SAndroid Build Coastguard Worker
GetInputLayerInVector(GraphPtr graph,int targetId)956*89c4ff92SAndroid Build Coastguard Worker unsigned int IDeserializer::DeserializerImpl::GetInputLayerInVector(GraphPtr graph, int targetId)
957*89c4ff92SAndroid Build Coastguard Worker {
958*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < graph->layers()->size(); i++)
959*89c4ff92SAndroid Build Coastguard Worker {
960*89c4ff92SAndroid Build Coastguard Worker auto layer = graph->layers()->Get(i);
961*89c4ff92SAndroid Build Coastguard Worker if (layer->layer_type() == Layer::Layer_InputLayer)
962*89c4ff92SAndroid Build Coastguard Worker {
963*89c4ff92SAndroid Build Coastguard Worker auto layerBindingId = layer->layer_as_InputLayer()->base()->layerBindingId();
964*89c4ff92SAndroid Build Coastguard Worker if (layerBindingId == targetId)
965*89c4ff92SAndroid Build Coastguard Worker {
966*89c4ff92SAndroid Build Coastguard Worker return i;
967*89c4ff92SAndroid Build Coastguard Worker }
968*89c4ff92SAndroid Build Coastguard Worker }
969*89c4ff92SAndroid Build Coastguard Worker }
970*89c4ff92SAndroid Build Coastguard Worker throw ParseException("Input layer with given layerBindingId not found");
971*89c4ff92SAndroid Build Coastguard Worker }
972*89c4ff92SAndroid Build Coastguard Worker
GetOutputLayerInVector(GraphPtr graph,int targetId)973*89c4ff92SAndroid Build Coastguard Worker unsigned int IDeserializer::DeserializerImpl::GetOutputLayerInVector(GraphPtr graph, int targetId)
974*89c4ff92SAndroid Build Coastguard Worker {
975*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < graph->layers()->size(); i++)
976*89c4ff92SAndroid Build Coastguard Worker {
977*89c4ff92SAndroid Build Coastguard Worker auto layer = graph->layers()->Get(i);
978*89c4ff92SAndroid Build Coastguard Worker if (layer->layer_type() == Layer::Layer_OutputLayer)
979*89c4ff92SAndroid Build Coastguard Worker {
980*89c4ff92SAndroid Build Coastguard Worker auto layerBindingId = layer->layer_as_OutputLayer()->base()->layerBindingId();
981*89c4ff92SAndroid Build Coastguard Worker if (layerBindingId == targetId)
982*89c4ff92SAndroid Build Coastguard Worker {
983*89c4ff92SAndroid Build Coastguard Worker return i;
984*89c4ff92SAndroid Build Coastguard Worker }
985*89c4ff92SAndroid Build Coastguard Worker }
986*89c4ff92SAndroid Build Coastguard Worker }
987*89c4ff92SAndroid Build Coastguard Worker throw ParseException("Output layer with given layerBindingId not found");
988*89c4ff92SAndroid Build Coastguard Worker }
989*89c4ff92SAndroid Build Coastguard Worker
GetLayerIndexInVector(GraphPtr graph,unsigned int targetIndex)990*89c4ff92SAndroid Build Coastguard Worker unsigned int IDeserializer::DeserializerImpl::GetLayerIndexInVector(GraphPtr graph, unsigned int targetIndex)
991*89c4ff92SAndroid Build Coastguard Worker {
992*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < graph->layers()->size(); i++)
993*89c4ff92SAndroid Build Coastguard Worker {
994*89c4ff92SAndroid Build Coastguard Worker LayerBaseRawPtr layer = GetBaseLayer(graph, i);
995*89c4ff92SAndroid Build Coastguard Worker if (layer->index() == targetIndex)
996*89c4ff92SAndroid Build Coastguard Worker {
997*89c4ff92SAndroid Build Coastguard Worker return i;
998*89c4ff92SAndroid Build Coastguard Worker }
999*89c4ff92SAndroid Build Coastguard Worker }
1000*89c4ff92SAndroid Build Coastguard Worker throw ParseException("Layer with given index not found");
1001*89c4ff92SAndroid Build Coastguard Worker }
1002*89c4ff92SAndroid Build Coastguard Worker
GetFeatureVersions(GraphPtr graph)1003*89c4ff92SAndroid Build Coastguard Worker IDeserializer::DeserializerImpl::FeatureVersions IDeserializer::DeserializerImpl::GetFeatureVersions(GraphPtr graph)
1004*89c4ff92SAndroid Build Coastguard Worker {
1005*89c4ff92SAndroid Build Coastguard Worker IDeserializer::DeserializerImpl::FeatureVersions versions;
1006*89c4ff92SAndroid Build Coastguard Worker
1007*89c4ff92SAndroid Build Coastguard Worker if (graph->featureVersions())
1008*89c4ff92SAndroid Build Coastguard Worker {
1009*89c4ff92SAndroid Build Coastguard Worker versions.m_BindingIdScheme = graph->featureVersions()->bindingIdsScheme();
1010*89c4ff92SAndroid Build Coastguard Worker versions.m_WeightsLayoutScheme = graph->featureVersions()->weightsLayoutScheme();
1011*89c4ff92SAndroid Build Coastguard Worker versions.m_ConstTensorsAsInputs = graph->featureVersions()->constantTensorsAsInputs();
1012*89c4ff92SAndroid Build Coastguard Worker }
1013*89c4ff92SAndroid Build Coastguard Worker
1014*89c4ff92SAndroid Build Coastguard Worker return versions;
1015*89c4ff92SAndroid Build Coastguard Worker }
1016*89c4ff92SAndroid Build Coastguard Worker
SetupInputLayers(GraphPtr graph)1017*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::SetupInputLayers(GraphPtr graph)
1018*89c4ff92SAndroid Build Coastguard Worker {
1019*89c4ff92SAndroid Build Coastguard Worker CHECK_GRAPH(graph, 0);
1020*89c4ff92SAndroid Build Coastguard Worker const unsigned int numInputs = graph->inputIds()->size();
1021*89c4ff92SAndroid Build Coastguard Worker m_InputBindings.clear();
1022*89c4ff92SAndroid Build Coastguard Worker m_InputBindings.reserve(numInputs);
1023*89c4ff92SAndroid Build Coastguard Worker
1024*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < numInputs; i++)
1025*89c4ff92SAndroid Build Coastguard Worker {
1026*89c4ff92SAndroid Build Coastguard Worker unsigned int inputLayerIndex = 0xFFFFFFFF;
1027*89c4ff92SAndroid Build Coastguard Worker if (GetFeatureVersions(graph).m_BindingIdScheme == 0)
1028*89c4ff92SAndroid Build Coastguard Worker {
1029*89c4ff92SAndroid Build Coastguard Worker const unsigned int inputId = armnn::numeric_cast<unsigned int>(graph->inputIds()->Get(i));
1030*89c4ff92SAndroid Build Coastguard Worker inputLayerIndex = GetLayerIndexInVector(graph, inputId);
1031*89c4ff92SAndroid Build Coastguard Worker }
1032*89c4ff92SAndroid Build Coastguard Worker else
1033*89c4ff92SAndroid Build Coastguard Worker {
1034*89c4ff92SAndroid Build Coastguard Worker const int inputId = graph->inputIds()->Get(i);
1035*89c4ff92SAndroid Build Coastguard Worker inputLayerIndex = GetInputLayerInVector(graph, inputId);
1036*89c4ff92SAndroid Build Coastguard Worker }
1037*89c4ff92SAndroid Build Coastguard Worker
1038*89c4ff92SAndroid Build Coastguard Worker LayerBaseRawPtr baseLayer = GetBaseLayer(graph, inputLayerIndex);
1039*89c4ff92SAndroid Build Coastguard Worker
1040*89c4ff92SAndroid Build Coastguard Worker // GetBindingLayerInfo expect the index to be index in the vector not index property on each layer base
1041*89c4ff92SAndroid Build Coastguard Worker LayerBindingId bindingId = GetBindingLayerInfo(graph, inputLayerIndex);
1042*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT_MSG(baseLayer->layerName()->c_str(), "Input has no name.");
1043*89c4ff92SAndroid Build Coastguard Worker
1044*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* inputLayer =
1045*89c4ff92SAndroid Build Coastguard Worker m_Network->AddInputLayer(bindingId, baseLayer->layerName()->c_str());
1046*89c4ff92SAndroid Build Coastguard Worker
1047*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& tensorInfo = ToTensorInfo(baseLayer->outputSlots()->Get(0)->tensorInfo());
1048*89c4ff92SAndroid Build Coastguard Worker inputLayer->GetOutputSlot(0).SetTensorInfo(tensorInfo);
1049*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, inputLayerIndex, inputLayer);
1050*89c4ff92SAndroid Build Coastguard Worker
1051*89c4ff92SAndroid Build Coastguard Worker BindingPointInfo bindingInfo = {bindingId, tensorInfo};
1052*89c4ff92SAndroid Build Coastguard Worker m_InputBindings.push_back(std::make_pair(baseLayer->layerName()->c_str(), bindingInfo));
1053*89c4ff92SAndroid Build Coastguard Worker }
1054*89c4ff92SAndroid Build Coastguard Worker }
1055*89c4ff92SAndroid Build Coastguard Worker
SetupOutputLayers(GraphPtr graph)1056*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::SetupOutputLayers(GraphPtr graph)
1057*89c4ff92SAndroid Build Coastguard Worker {
1058*89c4ff92SAndroid Build Coastguard Worker CHECK_GRAPH(graph, 0);
1059*89c4ff92SAndroid Build Coastguard Worker const unsigned int numOutputs = graph->outputIds()->size();
1060*89c4ff92SAndroid Build Coastguard Worker m_OutputBindings.clear();
1061*89c4ff92SAndroid Build Coastguard Worker m_OutputBindings.reserve(numOutputs);
1062*89c4ff92SAndroid Build Coastguard Worker
1063*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < numOutputs; i++)
1064*89c4ff92SAndroid Build Coastguard Worker {
1065*89c4ff92SAndroid Build Coastguard Worker unsigned int outputLayerIndex = 0xFFFFFFFF;
1066*89c4ff92SAndroid Build Coastguard Worker if (GetFeatureVersions(graph).m_BindingIdScheme == 0)
1067*89c4ff92SAndroid Build Coastguard Worker {
1068*89c4ff92SAndroid Build Coastguard Worker const unsigned int outputId = armnn::numeric_cast<unsigned int>(graph->outputIds()->Get(i));
1069*89c4ff92SAndroid Build Coastguard Worker outputLayerIndex = GetLayerIndexInVector(graph, outputId);
1070*89c4ff92SAndroid Build Coastguard Worker }
1071*89c4ff92SAndroid Build Coastguard Worker else
1072*89c4ff92SAndroid Build Coastguard Worker {
1073*89c4ff92SAndroid Build Coastguard Worker const int outputId = graph->outputIds()->Get(i);
1074*89c4ff92SAndroid Build Coastguard Worker outputLayerIndex = GetOutputLayerInVector(graph, outputId);
1075*89c4ff92SAndroid Build Coastguard Worker }
1076*89c4ff92SAndroid Build Coastguard Worker
1077*89c4ff92SAndroid Build Coastguard Worker LayerBaseRawPtr baseLayer = GetBaseLayer(graph, outputLayerIndex);
1078*89c4ff92SAndroid Build Coastguard Worker
1079*89c4ff92SAndroid Build Coastguard Worker // GetBindingLayerInfo expect the index to be index in the vector not index property on each layer base
1080*89c4ff92SAndroid Build Coastguard Worker LayerBindingId bindingId = GetBindingLayerInfo(graph, outputLayerIndex);
1081*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT_MSG(baseLayer->layerName()->c_str(), "Output has no name.");
1082*89c4ff92SAndroid Build Coastguard Worker
1083*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* outputLayer =
1084*89c4ff92SAndroid Build Coastguard Worker m_Network->AddOutputLayer(bindingId, baseLayer->layerName()->c_str());
1085*89c4ff92SAndroid Build Coastguard Worker
1086*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, outputLayerIndex, outputLayer);
1087*89c4ff92SAndroid Build Coastguard Worker unsigned int sourceLayerIndex =
1088*89c4ff92SAndroid Build Coastguard Worker GetLayerIndexInVector(graph, baseLayer->inputSlots()->Get(0)->connection()->sourceLayerIndex());
1089*89c4ff92SAndroid Build Coastguard Worker unsigned int outputSlotIndex =
1090*89c4ff92SAndroid Build Coastguard Worker GetLayerIndexInVector(graph, baseLayer->inputSlots()->Get(0)->connection()->outputSlotIndex());
1091*89c4ff92SAndroid Build Coastguard Worker LayerBaseRawPtr sourceBaseLayer = GetBaseLayer(graph, sourceLayerIndex);
1092*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& tensorInfo = ToTensorInfo(
1093*89c4ff92SAndroid Build Coastguard Worker sourceBaseLayer->outputSlots()->Get(outputSlotIndex)->tensorInfo());
1094*89c4ff92SAndroid Build Coastguard Worker BindingPointInfo bindingInfo = {bindingId, tensorInfo};
1095*89c4ff92SAndroid Build Coastguard Worker m_OutputBindings.push_back(std::make_pair(baseLayer->layerName()->c_str(), bindingInfo));
1096*89c4ff92SAndroid Build Coastguard Worker }
1097*89c4ff92SAndroid Build Coastguard Worker }
1098*89c4ff92SAndroid Build Coastguard Worker
RegisterOutputSlots(GraphPtr graph,uint32_t layerIndex,IConnectableLayer * layer)1099*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::RegisterOutputSlots(GraphPtr graph,
1100*89c4ff92SAndroid Build Coastguard Worker uint32_t layerIndex,
1101*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer)
1102*89c4ff92SAndroid Build Coastguard Worker {
1103*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
1104*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(layer != nullptr);
1105*89c4ff92SAndroid Build Coastguard Worker LayerBaseRawPtr baseLayer = GetBaseLayer(graph, layerIndex);
1106*89c4ff92SAndroid Build Coastguard Worker if (baseLayer->outputSlots()->size() != layer->GetNumOutputSlots())
1107*89c4ff92SAndroid Build Coastguard Worker {
1108*89c4ff92SAndroid Build Coastguard Worker throw ParseException(fmt::format("The number of outputslots ({0}) does not match the number expected ({1})"
1109*89c4ff92SAndroid Build Coastguard Worker " for layer index: {2} {3}",
1110*89c4ff92SAndroid Build Coastguard Worker baseLayer->outputSlots()->size(),
1111*89c4ff92SAndroid Build Coastguard Worker layer->GetNumOutputSlots(),
1112*89c4ff92SAndroid Build Coastguard Worker layerIndex,
1113*89c4ff92SAndroid Build Coastguard Worker CHECK_LOCATION().AsString()));
1114*89c4ff92SAndroid Build Coastguard Worker }
1115*89c4ff92SAndroid Build Coastguard Worker
1116*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < layer->GetNumOutputSlots(); ++i)
1117*89c4ff92SAndroid Build Coastguard Worker {
1118*89c4ff92SAndroid Build Coastguard Worker const unsigned int slotIndex = baseLayer->outputSlots()->Get(i)->index();
1119*89c4ff92SAndroid Build Coastguard Worker armnn::IOutputSlot* outputSlot = &(layer->GetOutputSlot(slotIndex));
1120*89c4ff92SAndroid Build Coastguard Worker // layerIndex is not necessarily the same as baseLayer->index(). The latter is needed here
1121*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlotOfConnection(baseLayer->index(), slotIndex, outputSlot);
1122*89c4ff92SAndroid Build Coastguard Worker }
1123*89c4ff92SAndroid Build Coastguard Worker }
1124*89c4ff92SAndroid Build Coastguard Worker
RegisterInputSlots(GraphPtr graph,uint32_t layerIndex,armnn::IConnectableLayer * layer,std::vector<unsigned int> ignoreSlots)1125*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::RegisterInputSlots(GraphPtr graph,
1126*89c4ff92SAndroid Build Coastguard Worker uint32_t layerIndex,
1127*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* layer,
1128*89c4ff92SAndroid Build Coastguard Worker std::vector<unsigned int> ignoreSlots)
1129*89c4ff92SAndroid Build Coastguard Worker {
1130*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
1131*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(layer != nullptr);
1132*89c4ff92SAndroid Build Coastguard Worker LayerBaseRawPtr baseLayer = GetBaseLayer(graph, layerIndex);
1133*89c4ff92SAndroid Build Coastguard Worker
1134*89c4ff92SAndroid Build Coastguard Worker if (baseLayer->inputSlots()->size() != (layer->GetNumInputSlots() - ignoreSlots.size()))
1135*89c4ff92SAndroid Build Coastguard Worker {
1136*89c4ff92SAndroid Build Coastguard Worker throw ParseException(fmt::format("The number of inputslots ({0}) does not match the number expected ({1})"
1137*89c4ff92SAndroid Build Coastguard Worker " for layer index:{2} {3}",
1138*89c4ff92SAndroid Build Coastguard Worker baseLayer->inputSlots()->size(),
1139*89c4ff92SAndroid Build Coastguard Worker layer->GetNumInputSlots(),
1140*89c4ff92SAndroid Build Coastguard Worker layerIndex,
1141*89c4ff92SAndroid Build Coastguard Worker CHECK_LOCATION().AsString()));
1142*89c4ff92SAndroid Build Coastguard Worker }
1143*89c4ff92SAndroid Build Coastguard Worker
1144*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < layer->GetNumInputSlots(); ++i)
1145*89c4ff92SAndroid Build Coastguard Worker {
1146*89c4ff92SAndroid Build Coastguard Worker // Check if slot should be ignored.
1147*89c4ff92SAndroid Build Coastguard Worker if (std::find(ignoreSlots.begin(), ignoreSlots.end(), i) == ignoreSlots.end())
1148*89c4ff92SAndroid Build Coastguard Worker {
1149*89c4ff92SAndroid Build Coastguard Worker auto fbInputSlot = baseLayer->inputSlots()->Get(i);
1150*89c4ff92SAndroid Build Coastguard Worker auto fbConnection = fbInputSlot->connection();
1151*89c4ff92SAndroid Build Coastguard Worker armnn::IInputSlot* inputSlot = &(layer->GetInputSlot(fbInputSlot->index()));
1152*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlotOfConnection(fbConnection->sourceLayerIndex(), fbConnection->outputSlotIndex(), inputSlot);
1153*89c4ff92SAndroid Build Coastguard Worker }
1154*89c4ff92SAndroid Build Coastguard Worker }
1155*89c4ff92SAndroid Build Coastguard Worker }
1156*89c4ff92SAndroid Build Coastguard Worker
RegisterInputSlotOfConnection(uint32_t sourceLayerIndex,uint32_t outputSlotIndex,armnn::IInputSlot * inputSlot)1157*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::RegisterInputSlotOfConnection(uint32_t sourceLayerIndex,
1158*89c4ff92SAndroid Build Coastguard Worker uint32_t outputSlotIndex,
1159*89c4ff92SAndroid Build Coastguard Worker armnn::IInputSlot* inputSlot)
1160*89c4ff92SAndroid Build Coastguard Worker {
1161*89c4ff92SAndroid Build Coastguard Worker if (m_GraphConnections.find(sourceLayerIndex) == m_GraphConnections.end())
1162*89c4ff92SAndroid Build Coastguard Worker {
1163*89c4ff92SAndroid Build Coastguard Worker m_GraphConnections[sourceLayerIndex] = Connections();
1164*89c4ff92SAndroid Build Coastguard Worker }
1165*89c4ff92SAndroid Build Coastguard Worker
1166*89c4ff92SAndroid Build Coastguard Worker Connections& connections = m_GraphConnections[sourceLayerIndex];
1167*89c4ff92SAndroid Build Coastguard Worker if (connections.inputSlots.find(outputSlotIndex) == connections.inputSlots.end())
1168*89c4ff92SAndroid Build Coastguard Worker {
1169*89c4ff92SAndroid Build Coastguard Worker connections.inputSlots[outputSlotIndex] = {inputSlot};
1170*89c4ff92SAndroid Build Coastguard Worker }
1171*89c4ff92SAndroid Build Coastguard Worker else
1172*89c4ff92SAndroid Build Coastguard Worker {
1173*89c4ff92SAndroid Build Coastguard Worker connections.inputSlots[outputSlotIndex].push_back(inputSlot);
1174*89c4ff92SAndroid Build Coastguard Worker }
1175*89c4ff92SAndroid Build Coastguard Worker }
1176*89c4ff92SAndroid Build Coastguard Worker
RegisterOutputSlotOfConnection(uint32_t sourceLayerIndex,uint32_t outputSlotIndex,armnn::IOutputSlot * outputSlot)1177*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::RegisterOutputSlotOfConnection(uint32_t sourceLayerIndex,
1178*89c4ff92SAndroid Build Coastguard Worker uint32_t outputSlotIndex,
1179*89c4ff92SAndroid Build Coastguard Worker armnn::IOutputSlot* outputSlot)
1180*89c4ff92SAndroid Build Coastguard Worker {
1181*89c4ff92SAndroid Build Coastguard Worker if (m_GraphConnections.find(sourceLayerIndex) == m_GraphConnections.end())
1182*89c4ff92SAndroid Build Coastguard Worker {
1183*89c4ff92SAndroid Build Coastguard Worker m_GraphConnections[sourceLayerIndex] = Connections();
1184*89c4ff92SAndroid Build Coastguard Worker }
1185*89c4ff92SAndroid Build Coastguard Worker
1186*89c4ff92SAndroid Build Coastguard Worker Connections& connections = m_GraphConnections[sourceLayerIndex];
1187*89c4ff92SAndroid Build Coastguard Worker if (connections.outputSlots.find(outputSlotIndex) != connections.outputSlots.end())
1188*89c4ff92SAndroid Build Coastguard Worker {
1189*89c4ff92SAndroid Build Coastguard Worker throw ParseException("Same output slot index processed twice");
1190*89c4ff92SAndroid Build Coastguard Worker }
1191*89c4ff92SAndroid Build Coastguard Worker
1192*89c4ff92SAndroid Build Coastguard Worker connections.outputSlots[outputSlotIndex] = outputSlot;
1193*89c4ff92SAndroid Build Coastguard Worker }
1194*89c4ff92SAndroid Build Coastguard Worker
ParseAbs(GraphPtr graph,unsigned int layerIndex)1195*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseAbs(GraphPtr graph, unsigned int layerIndex)
1196*89c4ff92SAndroid Build Coastguard Worker {
1197*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
1198*89c4ff92SAndroid Build Coastguard Worker auto inputs = GetInputs(graph, layerIndex);
1199*89c4ff92SAndroid Build Coastguard Worker CHECK_LOCATION();
1200*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 1);
1201*89c4ff92SAndroid Build Coastguard Worker
1202*89c4ff92SAndroid Build Coastguard Worker auto outputs = GetOutputs(graph, layerIndex);
1203*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
1204*89c4ff92SAndroid Build Coastguard Worker
1205*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
1206*89c4ff92SAndroid Build Coastguard Worker
1207*89c4ff92SAndroid Build Coastguard Worker armnn::ElementwiseUnaryDescriptor descriptor(armnn::UnaryOperation::Abs);
1208*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddElementwiseUnaryLayer(descriptor, layerName.c_str());
1209*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
1210*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
1211*89c4ff92SAndroid Build Coastguard Worker
1212*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
1213*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
1214*89c4ff92SAndroid Build Coastguard Worker }
1215*89c4ff92SAndroid Build Coastguard Worker
ParseActivation(GraphPtr graph,unsigned int layerIndex)1216*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseActivation(GraphPtr graph, unsigned int layerIndex)
1217*89c4ff92SAndroid Build Coastguard Worker {
1218*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
1219*89c4ff92SAndroid Build Coastguard Worker auto inputs = GetInputs(graph, layerIndex);
1220*89c4ff92SAndroid Build Coastguard Worker CHECK_LOCATION();
1221*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 1);
1222*89c4ff92SAndroid Build Coastguard Worker
1223*89c4ff92SAndroid Build Coastguard Worker auto outputs = GetOutputs(graph, layerIndex);
1224*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
1225*89c4ff92SAndroid Build Coastguard Worker
1226*89c4ff92SAndroid Build Coastguard Worker auto serializerLayer = graph->layers()->Get(layerIndex)->layer_as_ActivationLayer();
1227*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
1228*89c4ff92SAndroid Build Coastguard Worker auto serializerDescriptor = serializerLayer->descriptor();
1229*89c4ff92SAndroid Build Coastguard Worker
1230*89c4ff92SAndroid Build Coastguard Worker armnn::ActivationDescriptor descriptor;
1231*89c4ff92SAndroid Build Coastguard Worker descriptor.m_Function = ToActivationFunction(serializerDescriptor->activationFunction());
1232*89c4ff92SAndroid Build Coastguard Worker descriptor.m_A = serializerDescriptor->a();
1233*89c4ff92SAndroid Build Coastguard Worker descriptor.m_B = serializerDescriptor->b();
1234*89c4ff92SAndroid Build Coastguard Worker
1235*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddActivationLayer(descriptor,
1236*89c4ff92SAndroid Build Coastguard Worker layerName.c_str());
1237*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
1238*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
1239*89c4ff92SAndroid Build Coastguard Worker
1240*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
1241*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
1242*89c4ff92SAndroid Build Coastguard Worker }
1243*89c4ff92SAndroid Build Coastguard Worker
ParseAdd(GraphPtr graph,unsigned int layerIndex)1244*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseAdd(GraphPtr graph, unsigned int layerIndex)
1245*89c4ff92SAndroid Build Coastguard Worker {
1246*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
1247*89c4ff92SAndroid Build Coastguard Worker auto inputs = GetInputs(graph, layerIndex);
1248*89c4ff92SAndroid Build Coastguard Worker CHECK_LOCATION();
1249*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 2);
1250*89c4ff92SAndroid Build Coastguard Worker
1251*89c4ff92SAndroid Build Coastguard Worker auto outputs = GetOutputs(graph, layerIndex);
1252*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
1253*89c4ff92SAndroid Build Coastguard Worker
1254*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
1255*89c4ff92SAndroid Build Coastguard Worker armnn::ElementwiseBinaryDescriptor descriptor(armnn::BinaryOperation::Add);
1256*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddElementwiseBinaryLayer(descriptor, layerName.c_str());
1257*89c4ff92SAndroid Build Coastguard Worker
1258*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
1259*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
1260*89c4ff92SAndroid Build Coastguard Worker
1261*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
1262*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
1263*89c4ff92SAndroid Build Coastguard Worker }
1264*89c4ff92SAndroid Build Coastguard Worker
ParseArgMinMax(GraphPtr graph,unsigned int layerIndex)1265*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseArgMinMax(GraphPtr graph, unsigned int layerIndex)
1266*89c4ff92SAndroid Build Coastguard Worker {
1267*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
1268*89c4ff92SAndroid Build Coastguard Worker auto inputs = GetInputs(graph, layerIndex);
1269*89c4ff92SAndroid Build Coastguard Worker CHECK_LOCATION();
1270*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 1);
1271*89c4ff92SAndroid Build Coastguard Worker
1272*89c4ff92SAndroid Build Coastguard Worker auto outputs = GetOutputs(graph, layerIndex);
1273*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
1274*89c4ff92SAndroid Build Coastguard Worker
1275*89c4ff92SAndroid Build Coastguard Worker auto serializerLayer = graph->layers()->Get(layerIndex)->layer_as_ArgMinMaxLayer();
1276*89c4ff92SAndroid Build Coastguard Worker auto serializerDescriptor = serializerLayer->descriptor();
1277*89c4ff92SAndroid Build Coastguard Worker
1278*89c4ff92SAndroid Build Coastguard Worker armnn::ArgMinMaxDescriptor descriptor;
1279*89c4ff92SAndroid Build Coastguard Worker descriptor.m_Function = ToArgMinMaxFunction(serializerDescriptor->argMinMaxFunction());
1280*89c4ff92SAndroid Build Coastguard Worker descriptor.m_Axis = serializerDescriptor->axis();
1281*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
1282*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddArgMinMaxLayer(descriptor, layerName.c_str());
1283*89c4ff92SAndroid Build Coastguard Worker
1284*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
1285*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
1286*89c4ff92SAndroid Build Coastguard Worker
1287*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
1288*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
1289*89c4ff92SAndroid Build Coastguard Worker }
1290*89c4ff92SAndroid Build Coastguard Worker
ParseBatchMatMul(GraphPtr graph,unsigned int layerIndex)1291*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseBatchMatMul(GraphPtr graph, unsigned int layerIndex)
1292*89c4ff92SAndroid Build Coastguard Worker {
1293*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
1294*89c4ff92SAndroid Build Coastguard Worker
1295*89c4ff92SAndroid Build Coastguard Worker auto inputs = GetInputs(graph, layerIndex);
1296*89c4ff92SAndroid Build Coastguard Worker CHECK_LOCATION();
1297*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 2);
1298*89c4ff92SAndroid Build Coastguard Worker
1299*89c4ff92SAndroid Build Coastguard Worker auto outputs = GetOutputs(graph, layerIndex);
1300*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
1301*89c4ff92SAndroid Build Coastguard Worker
1302*89c4ff92SAndroid Build Coastguard Worker auto serializerLayer = graph->layers()->Get(layerIndex)->layer_as_BatchMatMulLayer();
1303*89c4ff92SAndroid Build Coastguard Worker auto serializerDescriptor = serializerLayer->descriptor();
1304*89c4ff92SAndroid Build Coastguard Worker
1305*89c4ff92SAndroid Build Coastguard Worker armnn::BatchMatMulDescriptor descriptor(serializerDescriptor->transposeX(),
1306*89c4ff92SAndroid Build Coastguard Worker serializerDescriptor->transposeY(),
1307*89c4ff92SAndroid Build Coastguard Worker serializerDescriptor->adjointX(),
1308*89c4ff92SAndroid Build Coastguard Worker serializerDescriptor->adjointY(),
1309*89c4ff92SAndroid Build Coastguard Worker ToDataLayout(serializerDescriptor->dataLayoutX()),
1310*89c4ff92SAndroid Build Coastguard Worker ToDataLayout(serializerDescriptor->dataLayoutY()));
1311*89c4ff92SAndroid Build Coastguard Worker
1312*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
1313*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddBatchMatMulLayer(descriptor, layerName.c_str());
1314*89c4ff92SAndroid Build Coastguard Worker
1315*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
1316*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
1317*89c4ff92SAndroid Build Coastguard Worker
1318*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
1319*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
1320*89c4ff92SAndroid Build Coastguard Worker }
1321*89c4ff92SAndroid Build Coastguard Worker
ParseBatchToSpaceNd(GraphPtr graph,unsigned int layerIndex)1322*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseBatchToSpaceNd(GraphPtr graph, unsigned int layerIndex)
1323*89c4ff92SAndroid Build Coastguard Worker {
1324*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
1325*89c4ff92SAndroid Build Coastguard Worker
1326*89c4ff92SAndroid Build Coastguard Worker TensorRawPtrVector inputs = GetInputs(graph, layerIndex);
1327*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 1);
1328*89c4ff92SAndroid Build Coastguard Worker
1329*89c4ff92SAndroid Build Coastguard Worker TensorRawPtrVector outputs = GetOutputs(graph, layerIndex);
1330*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
1331*89c4ff92SAndroid Build Coastguard Worker
1332*89c4ff92SAndroid Build Coastguard Worker auto flatBufferDescriptor = graph->layers()->Get(layerIndex)->layer_as_BatchToSpaceNdLayer()->descriptor();
1333*89c4ff92SAndroid Build Coastguard Worker auto flatBufferCrops = flatBufferDescriptor->crops();
1334*89c4ff92SAndroid Build Coastguard Worker auto flatBufferBlockShape = flatBufferDescriptor->blockShape();
1335*89c4ff92SAndroid Build Coastguard Worker
1336*89c4ff92SAndroid Build Coastguard Worker if (flatBufferCrops->size() % 2 != 0)
1337*89c4ff92SAndroid Build Coastguard Worker {
1338*89c4ff92SAndroid Build Coastguard Worker throw ParseException(fmt::format("The size of crops must be divisible by 2 {}", CHECK_LOCATION().AsString()));
1339*89c4ff92SAndroid Build Coastguard Worker }
1340*89c4ff92SAndroid Build Coastguard Worker
1341*89c4ff92SAndroid Build Coastguard Worker std::vector<std::pair<unsigned int, unsigned int>> crops;
1342*89c4ff92SAndroid Build Coastguard Worker crops.reserve(flatBufferCrops->size() / 2);
1343*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < flatBufferCrops->size() - 1; i += 2)
1344*89c4ff92SAndroid Build Coastguard Worker {
1345*89c4ff92SAndroid Build Coastguard Worker crops.emplace_back(flatBufferCrops->Get(i), flatBufferCrops->Get(i+1));
1346*89c4ff92SAndroid Build Coastguard Worker }
1347*89c4ff92SAndroid Build Coastguard Worker
1348*89c4ff92SAndroid Build Coastguard Worker armnn::BatchToSpaceNdDescriptor descriptor;
1349*89c4ff92SAndroid Build Coastguard Worker descriptor.m_DataLayout = ToDataLayout(flatBufferDescriptor->dataLayout());
1350*89c4ff92SAndroid Build Coastguard Worker descriptor.m_BlockShape =
1351*89c4ff92SAndroid Build Coastguard Worker std::vector<unsigned int>(flatBufferBlockShape->begin(), flatBufferBlockShape->end());
1352*89c4ff92SAndroid Build Coastguard Worker descriptor.m_Crops = crops;
1353*89c4ff92SAndroid Build Coastguard Worker
1354*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
1355*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddBatchToSpaceNdLayer(descriptor, layerName.c_str());
1356*89c4ff92SAndroid Build Coastguard Worker
1357*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
1358*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
1359*89c4ff92SAndroid Build Coastguard Worker
1360*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
1361*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
1362*89c4ff92SAndroid Build Coastguard Worker }
1363*89c4ff92SAndroid Build Coastguard Worker
ParseBatchNormalization(GraphPtr graph,unsigned int layerIndex)1364*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseBatchNormalization(GraphPtr graph, unsigned int layerIndex)
1365*89c4ff92SAndroid Build Coastguard Worker {
1366*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
1367*89c4ff92SAndroid Build Coastguard Worker
1368*89c4ff92SAndroid Build Coastguard Worker auto inputs = GetInputs(graph, layerIndex);
1369*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 1);
1370*89c4ff92SAndroid Build Coastguard Worker
1371*89c4ff92SAndroid Build Coastguard Worker auto outputs = GetOutputs(graph, layerIndex);
1372*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
1373*89c4ff92SAndroid Build Coastguard Worker auto outputInfo = ToTensorInfo(outputs[0]);
1374*89c4ff92SAndroid Build Coastguard Worker
1375*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
1376*89c4ff92SAndroid Build Coastguard Worker
1377*89c4ff92SAndroid Build Coastguard Worker auto serializerLayer = graph->layers()->Get(layerIndex)->layer_as_BatchNormalizationLayer();
1378*89c4ff92SAndroid Build Coastguard Worker auto serializerDescriptor = serializerLayer->descriptor();
1379*89c4ff92SAndroid Build Coastguard Worker
1380*89c4ff92SAndroid Build Coastguard Worker armnn::BatchNormalizationDescriptor descriptor;
1381*89c4ff92SAndroid Build Coastguard Worker descriptor.m_Eps = serializerDescriptor->eps();
1382*89c4ff92SAndroid Build Coastguard Worker descriptor.m_DataLayout = ToDataLayout(serializerDescriptor->dataLayout());
1383*89c4ff92SAndroid Build Coastguard Worker
1384*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor mean = ToConstTensor(serializerLayer->mean());
1385*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor variance = ToConstTensor(serializerLayer->variance());
1386*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor beta = ToConstTensor(serializerLayer->beta());
1387*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor gamma = ToConstTensor(serializerLayer->gamma());
1388*89c4ff92SAndroid Build Coastguard Worker
1389*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddBatchNormalizationLayer(descriptor,
1390*89c4ff92SAndroid Build Coastguard Worker mean,
1391*89c4ff92SAndroid Build Coastguard Worker variance,
1392*89c4ff92SAndroid Build Coastguard Worker beta,
1393*89c4ff92SAndroid Build Coastguard Worker gamma,
1394*89c4ff92SAndroid Build Coastguard Worker layerName.c_str());
1395*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputInfo);
1396*89c4ff92SAndroid Build Coastguard Worker
1397*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
1398*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
1399*89c4ff92SAndroid Build Coastguard Worker }
1400*89c4ff92SAndroid Build Coastguard Worker
ParseCast(GraphPtr graph,unsigned int layerIndex)1401*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseCast(GraphPtr graph, unsigned int layerIndex)
1402*89c4ff92SAndroid Build Coastguard Worker {
1403*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
1404*89c4ff92SAndroid Build Coastguard Worker TensorRawPtrVector inputs = GetInputs(graph, layerIndex);
1405*89c4ff92SAndroid Build Coastguard Worker CHECK_LOCATION();
1406*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 1);
1407*89c4ff92SAndroid Build Coastguard Worker
1408*89c4ff92SAndroid Build Coastguard Worker TensorRawPtrVector outputs = GetOutputs(graph, layerIndex);
1409*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
1410*89c4ff92SAndroid Build Coastguard Worker
1411*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
1412*89c4ff92SAndroid Build Coastguard Worker
1413*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddCastLayer(layerName.c_str());
1414*89c4ff92SAndroid Build Coastguard Worker
1415*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
1416*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
1417*89c4ff92SAndroid Build Coastguard Worker
1418*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
1419*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
1420*89c4ff92SAndroid Build Coastguard Worker }
1421*89c4ff92SAndroid Build Coastguard Worker
ParseConstant(GraphPtr graph,unsigned int layerIndex)1422*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseConstant(GraphPtr graph, unsigned int layerIndex)
1423*89c4ff92SAndroid Build Coastguard Worker {
1424*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
1425*89c4ff92SAndroid Build Coastguard Worker CHECK_LOCATION();
1426*89c4ff92SAndroid Build Coastguard Worker
1427*89c4ff92SAndroid Build Coastguard Worker auto outputs = GetOutputs(graph, layerIndex);
1428*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
1429*89c4ff92SAndroid Build Coastguard Worker
1430*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
1431*89c4ff92SAndroid Build Coastguard Worker
1432*89c4ff92SAndroid Build Coastguard Worker auto serializerLayer = graph->layers()->Get(layerIndex)->layer_as_ConstantLayer();
1433*89c4ff92SAndroid Build Coastguard Worker auto serializerInput = serializerLayer->input();
1434*89c4ff92SAndroid Build Coastguard Worker
1435*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor input = ToConstTensor(serializerInput);
1436*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer;
1437*89c4ff92SAndroid Build Coastguard Worker
1438*89c4ff92SAndroid Build Coastguard Worker // Required for when Constant Layer is used as an inputs to DepthwiseConvolution2d Layer.
1439*89c4ff92SAndroid Build Coastguard Worker // Running a model that was created before weights layout scheme version was added to our flatbuffers
1440*89c4ff92SAndroid Build Coastguard Worker // file ensuring older models can still be read and executed. featureVersion weights layout scheme 1
1441*89c4ff92SAndroid Build Coastguard Worker // indicates a change in the depthwise weights layout within ArmNN from [M,I,H,W] --> [1,H,W,I*M]
1442*89c4ff92SAndroid Build Coastguard Worker if (this->GetFeatureVersions(graph).m_WeightsLayoutScheme <= 0)
1443*89c4ff92SAndroid Build Coastguard Worker {
1444*89c4ff92SAndroid Build Coastguard Worker // Permute weights [ H, W, M, I ] --> [ 1, H, W, I*M ]
1445*89c4ff92SAndroid Build Coastguard Worker // Step1: [ M, I, H, W ] --> [ H, W, I, M]
1446*89c4ff92SAndroid Build Coastguard Worker PermutationVector permutationVector = { 3, 2, 0, 1 };
1447*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo weightsInfo = input.GetInfo();
1448*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<unsigned char[]> permuteBuffer(new unsigned char[weightsInfo.GetNumBytes()]);
1449*89c4ff92SAndroid Build Coastguard Worker weightsInfo = armnnUtils::Permuted(weightsInfo, permutationVector);
1450*89c4ff92SAndroid Build Coastguard Worker armnnUtils::Permute(weightsInfo.GetShape(), permutationVector,
1451*89c4ff92SAndroid Build Coastguard Worker input.GetMemoryArea(), permuteBuffer.get(),
1452*89c4ff92SAndroid Build Coastguard Worker GetDataTypeSize(weightsInfo.GetDataType()));
1453*89c4ff92SAndroid Build Coastguard Worker
1454*89c4ff92SAndroid Build Coastguard Worker // Step2: Reshape [ H, W, I, M] --> [ 1, H, W, I*M ]
1455*89c4ff92SAndroid Build Coastguard Worker auto weightsShape = weightsInfo.GetShape();
1456*89c4ff92SAndroid Build Coastguard Worker weightsInfo.SetShape({1,
1457*89c4ff92SAndroid Build Coastguard Worker weightsShape[0],
1458*89c4ff92SAndroid Build Coastguard Worker weightsShape[1],
1459*89c4ff92SAndroid Build Coastguard Worker weightsShape[2]*weightsShape[3]});
1460*89c4ff92SAndroid Build Coastguard Worker weightsInfo.SetConstant(true);
1461*89c4ff92SAndroid Build Coastguard Worker
1462*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor weightsPermuted(weightsInfo, permuteBuffer.get());
1463*89c4ff92SAndroid Build Coastguard Worker
1464*89c4ff92SAndroid Build Coastguard Worker layer = m_Network->AddConstantLayer(weightsPermuted, layerName.c_str());
1465*89c4ff92SAndroid Build Coastguard Worker
1466*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(weightsPermuted.GetInfo());
1467*89c4ff92SAndroid Build Coastguard Worker
1468*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
1469*89c4ff92SAndroid Build Coastguard Worker
1470*89c4ff92SAndroid Build Coastguard Worker return;
1471*89c4ff92SAndroid Build Coastguard Worker }
1472*89c4ff92SAndroid Build Coastguard Worker else
1473*89c4ff92SAndroid Build Coastguard Worker {
1474*89c4ff92SAndroid Build Coastguard Worker layer = m_Network->AddConstantLayer(input, layerName.c_str());
1475*89c4ff92SAndroid Build Coastguard Worker
1476*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
1477*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo.SetConstant(true);
1478*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
1479*89c4ff92SAndroid Build Coastguard Worker }
1480*89c4ff92SAndroid Build Coastguard Worker
1481*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
1482*89c4ff92SAndroid Build Coastguard Worker }
1483*89c4ff92SAndroid Build Coastguard Worker
ParseConvolution2d(GraphPtr graph,unsigned int layerIndex)1484*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseConvolution2d(GraphPtr graph, unsigned int layerIndex)
1485*89c4ff92SAndroid Build Coastguard Worker {
1486*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
1487*89c4ff92SAndroid Build Coastguard Worker auto inputs = GetInputs(graph, layerIndex);
1488*89c4ff92SAndroid Build Coastguard Worker CHECK_LOCATION();
1489*89c4ff92SAndroid Build Coastguard Worker
1490*89c4ff92SAndroid Build Coastguard Worker auto outputs = GetOutputs(graph, layerIndex);
1491*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
1492*89c4ff92SAndroid Build Coastguard Worker
1493*89c4ff92SAndroid Build Coastguard Worker auto flatBufferLayer = graph->layers()->Get(layerIndex)->layer_as_Convolution2dLayer();
1494*89c4ff92SAndroid Build Coastguard Worker
1495*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
1496*89c4ff92SAndroid Build Coastguard Worker auto flatbufferDescriptor = flatBufferLayer->descriptor();
1497*89c4ff92SAndroid Build Coastguard Worker
1498*89c4ff92SAndroid Build Coastguard Worker armnn::Convolution2dDescriptor descriptor;
1499*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadLeft = flatbufferDescriptor->padLeft();
1500*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadRight = flatbufferDescriptor->padRight();
1501*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadTop = flatbufferDescriptor->padTop();
1502*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadBottom = flatbufferDescriptor->padBottom();
1503*89c4ff92SAndroid Build Coastguard Worker descriptor.m_StrideX = flatbufferDescriptor->strideX();
1504*89c4ff92SAndroid Build Coastguard Worker descriptor.m_StrideY = flatbufferDescriptor->strideY();;
1505*89c4ff92SAndroid Build Coastguard Worker descriptor.m_DilationX = flatbufferDescriptor->dilationX();
1506*89c4ff92SAndroid Build Coastguard Worker descriptor.m_DilationY = flatbufferDescriptor->dilationY();;
1507*89c4ff92SAndroid Build Coastguard Worker descriptor.m_BiasEnabled = flatbufferDescriptor->biasEnabled();;
1508*89c4ff92SAndroid Build Coastguard Worker descriptor.m_DataLayout = ToDataLayout(flatbufferDescriptor->dataLayout());
1509*89c4ff92SAndroid Build Coastguard Worker
1510*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* layer;
1511*89c4ff92SAndroid Build Coastguard Worker std::vector<unsigned int> ignoreSlots {};
1512*89c4ff92SAndroid Build Coastguard Worker
1513*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor biasTensor;
1514*89c4ff92SAndroid Build Coastguard Worker // Weights and biases used to be always constant and were stored as members of the layer. This has changed and
1515*89c4ff92SAndroid Build Coastguard Worker // they are now passed as inputs. If they are constant then they will be stored in a ConstantLayer.
1516*89c4ff92SAndroid Build Coastguard Worker if (this->GetFeatureVersions(graph).m_ConstTensorsAsInputs <= 0)
1517*89c4ff92SAndroid Build Coastguard Worker {
1518*89c4ff92SAndroid Build Coastguard Worker // If the model stores weights and biases as members of the layer we have to read them from there
1519*89c4ff92SAndroid Build Coastguard Worker // but add them to their own ConstantLayer for compatibility
1520*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 1);
1521*89c4ff92SAndroid Build Coastguard Worker
1522*89c4ff92SAndroid Build Coastguard Worker layer = m_Network->AddConvolution2dLayer(descriptor,
1523*89c4ff92SAndroid Build Coastguard Worker layerName.c_str());
1524*89c4ff92SAndroid Build Coastguard Worker
1525*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor weightsTensor = ToConstTensor(flatBufferLayer->weights());
1526*89c4ff92SAndroid Build Coastguard Worker auto weightsLayer = m_Network->AddConstantLayer(weightsTensor);
1527*89c4ff92SAndroid Build Coastguard Worker weightsLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1u));
1528*89c4ff92SAndroid Build Coastguard Worker weightsLayer->GetOutputSlot(0).SetTensorInfo(weightsTensor.GetInfo());
1529*89c4ff92SAndroid Build Coastguard Worker ignoreSlots.emplace_back(1u);
1530*89c4ff92SAndroid Build Coastguard Worker
1531*89c4ff92SAndroid Build Coastguard Worker if (descriptor.m_BiasEnabled)
1532*89c4ff92SAndroid Build Coastguard Worker {
1533*89c4ff92SAndroid Build Coastguard Worker biasTensor = ToConstTensor(flatBufferLayer->biases());
1534*89c4ff92SAndroid Build Coastguard Worker auto biasLayer = m_Network->AddConstantLayer(biasTensor);
1535*89c4ff92SAndroid Build Coastguard Worker biasLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(2u));
1536*89c4ff92SAndroid Build Coastguard Worker biasLayer->GetOutputSlot(0).SetTensorInfo(biasTensor.GetInfo());
1537*89c4ff92SAndroid Build Coastguard Worker ignoreSlots.emplace_back(2u);
1538*89c4ff92SAndroid Build Coastguard Worker }
1539*89c4ff92SAndroid Build Coastguard Worker }
1540*89c4ff92SAndroid Build Coastguard Worker else
1541*89c4ff92SAndroid Build Coastguard Worker {
1542*89c4ff92SAndroid Build Coastguard Worker layer = m_Network->AddConvolution2dLayer(descriptor,
1543*89c4ff92SAndroid Build Coastguard Worker layerName.c_str());
1544*89c4ff92SAndroid Build Coastguard Worker uint32_t numInputs = descriptor.GetNumInputs();
1545*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), numInputs);
1546*89c4ff92SAndroid Build Coastguard Worker }
1547*89c4ff92SAndroid Build Coastguard Worker
1548*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
1549*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
1550*89c4ff92SAndroid Build Coastguard Worker
1551*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer, ignoreSlots);
1552*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
1553*89c4ff92SAndroid Build Coastguard Worker }
1554*89c4ff92SAndroid Build Coastguard Worker
ParseConvolution3d(GraphPtr graph,unsigned int layerIndex)1555*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseConvolution3d(GraphPtr graph, unsigned int layerIndex)
1556*89c4ff92SAndroid Build Coastguard Worker {
1557*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
1558*89c4ff92SAndroid Build Coastguard Worker auto inputs = GetInputs(graph, layerIndex);
1559*89c4ff92SAndroid Build Coastguard Worker CHECK_LOCATION();
1560*89c4ff92SAndroid Build Coastguard Worker
1561*89c4ff92SAndroid Build Coastguard Worker auto outputs = GetOutputs(graph, layerIndex);
1562*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
1563*89c4ff92SAndroid Build Coastguard Worker
1564*89c4ff92SAndroid Build Coastguard Worker auto serializerLayer = graph->layers()->Get(layerIndex)->layer_as_Convolution3dLayer();
1565*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
1566*89c4ff92SAndroid Build Coastguard Worker auto serializerDescriptor = serializerLayer->descriptor();
1567*89c4ff92SAndroid Build Coastguard Worker
1568*89c4ff92SAndroid Build Coastguard Worker armnn::Convolution3dDescriptor descriptor;
1569*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadLeft = serializerDescriptor->padLeft();
1570*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadRight = serializerDescriptor->padRight();
1571*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadTop = serializerDescriptor->padTop();
1572*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadBottom = serializerDescriptor->padBottom();
1573*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadFront = serializerDescriptor->padFront();
1574*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadBack = serializerDescriptor->padBack();
1575*89c4ff92SAndroid Build Coastguard Worker descriptor.m_StrideX = serializerDescriptor->strideX();
1576*89c4ff92SAndroid Build Coastguard Worker descriptor.m_StrideY = serializerDescriptor->strideY();
1577*89c4ff92SAndroid Build Coastguard Worker descriptor.m_StrideZ = serializerDescriptor->strideZ();
1578*89c4ff92SAndroid Build Coastguard Worker descriptor.m_DilationX = serializerDescriptor->dilationX();
1579*89c4ff92SAndroid Build Coastguard Worker descriptor.m_DilationY = serializerDescriptor->dilationY();
1580*89c4ff92SAndroid Build Coastguard Worker descriptor.m_DilationZ = serializerDescriptor->dilationZ();
1581*89c4ff92SAndroid Build Coastguard Worker descriptor.m_BiasEnabled = serializerDescriptor->biasEnabled();
1582*89c4ff92SAndroid Build Coastguard Worker descriptor.m_DataLayout = ToDataLayout(serializerDescriptor->dataLayout());
1583*89c4ff92SAndroid Build Coastguard Worker
1584*89c4ff92SAndroid Build Coastguard Worker uint32_t numInputs = descriptor.GetNumInputs();
1585*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), numInputs);
1586*89c4ff92SAndroid Build Coastguard Worker
1587*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddConvolution3dLayer(descriptor, layerName.c_str());
1588*89c4ff92SAndroid Build Coastguard Worker
1589*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
1590*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
1591*89c4ff92SAndroid Build Coastguard Worker
1592*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
1593*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
1594*89c4ff92SAndroid Build Coastguard Worker }
1595*89c4ff92SAndroid Build Coastguard Worker
ParseDepthToSpace(GraphPtr graph,unsigned int layerIndex)1596*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseDepthToSpace(GraphPtr graph, unsigned int layerIndex)
1597*89c4ff92SAndroid Build Coastguard Worker {
1598*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
1599*89c4ff92SAndroid Build Coastguard Worker
1600*89c4ff92SAndroid Build Coastguard Worker auto inputs = GetInputs(graph, layerIndex);
1601*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 1);
1602*89c4ff92SAndroid Build Coastguard Worker
1603*89c4ff92SAndroid Build Coastguard Worker auto outputs = GetOutputs(graph, layerIndex);
1604*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
1605*89c4ff92SAndroid Build Coastguard Worker
1606*89c4ff92SAndroid Build Coastguard Worker auto fbDescriptor = graph->layers()->Get(layerIndex)->layer_as_DepthToSpaceLayer()->descriptor();
1607*89c4ff92SAndroid Build Coastguard Worker
1608*89c4ff92SAndroid Build Coastguard Worker armnn::DepthToSpaceDescriptor descriptor;
1609*89c4ff92SAndroid Build Coastguard Worker descriptor.m_BlockSize = fbDescriptor->blockSize();
1610*89c4ff92SAndroid Build Coastguard Worker descriptor.m_DataLayout = ToDataLayout(fbDescriptor->dataLayout());
1611*89c4ff92SAndroid Build Coastguard Worker
1612*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
1613*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddDepthToSpaceLayer(descriptor, layerName.c_str());
1614*89c4ff92SAndroid Build Coastguard Worker
1615*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputInfo = ToTensorInfo(outputs[0]);
1616*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputInfo);
1617*89c4ff92SAndroid Build Coastguard Worker
1618*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
1619*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
1620*89c4ff92SAndroid Build Coastguard Worker }
1621*89c4ff92SAndroid Build Coastguard Worker
ParseDepthwiseConvolution2d(GraphPtr graph,unsigned int layerIndex)1622*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseDepthwiseConvolution2d(GraphPtr graph, unsigned int layerIndex)
1623*89c4ff92SAndroid Build Coastguard Worker {
1624*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
1625*89c4ff92SAndroid Build Coastguard Worker auto inputs = GetInputs(graph, layerIndex);
1626*89c4ff92SAndroid Build Coastguard Worker CHECK_LOCATION();
1627*89c4ff92SAndroid Build Coastguard Worker
1628*89c4ff92SAndroid Build Coastguard Worker auto outputs = GetOutputs(graph, layerIndex);
1629*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
1630*89c4ff92SAndroid Build Coastguard Worker
1631*89c4ff92SAndroid Build Coastguard Worker auto serializerLayer = graph->layers()->Get(layerIndex)->layer_as_DepthwiseConvolution2dLayer();
1632*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
1633*89c4ff92SAndroid Build Coastguard Worker auto serializerDescriptor = serializerLayer->descriptor();
1634*89c4ff92SAndroid Build Coastguard Worker
1635*89c4ff92SAndroid Build Coastguard Worker armnn::DepthwiseConvolution2dDescriptor descriptor;
1636*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadLeft = serializerDescriptor->padLeft();
1637*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadRight = serializerDescriptor->padRight();
1638*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadTop = serializerDescriptor->padTop();
1639*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadBottom = serializerDescriptor->padBottom();
1640*89c4ff92SAndroid Build Coastguard Worker descriptor.m_StrideX = serializerDescriptor->strideX();
1641*89c4ff92SAndroid Build Coastguard Worker descriptor.m_StrideY = serializerDescriptor->strideY();
1642*89c4ff92SAndroid Build Coastguard Worker descriptor.m_DilationX = serializerDescriptor->dilationX();
1643*89c4ff92SAndroid Build Coastguard Worker descriptor.m_DilationY = serializerDescriptor->dilationY();
1644*89c4ff92SAndroid Build Coastguard Worker descriptor.m_BiasEnabled = serializerDescriptor->biasEnabled();
1645*89c4ff92SAndroid Build Coastguard Worker descriptor.m_DataLayout = ToDataLayout(serializerDescriptor->dataLayout());
1646*89c4ff92SAndroid Build Coastguard Worker
1647*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer;
1648*89c4ff92SAndroid Build Coastguard Worker std::vector<unsigned int> ignoreSlots {};
1649*89c4ff92SAndroid Build Coastguard Worker
1650*89c4ff92SAndroid Build Coastguard Worker // Weights and biases used to be always constant and were stored as members of the layer. This has changed and
1651*89c4ff92SAndroid Build Coastguard Worker // they are now passed as inputs. If they are constant then they will be stored in a ConstantLayer.
1652*89c4ff92SAndroid Build Coastguard Worker if (this->GetFeatureVersions(graph).m_ConstTensorsAsInputs <= 0)
1653*89c4ff92SAndroid Build Coastguard Worker {
1654*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 1);
1655*89c4ff92SAndroid Build Coastguard Worker
1656*89c4ff92SAndroid Build Coastguard Worker // If the model stores weights and biases as members of the layer we have to read them from there
1657*89c4ff92SAndroid Build Coastguard Worker // but add them to their own ConstantLayer for compatibility
1658*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor weights = ToConstTensor(serializerLayer->weights());
1659*89c4ff92SAndroid Build Coastguard Worker ignoreSlots.emplace_back(1u);
1660*89c4ff92SAndroid Build Coastguard Worker
1661*89c4ff92SAndroid Build Coastguard Worker layer = m_Network->AddDepthwiseConvolution2dLayer(descriptor,
1662*89c4ff92SAndroid Build Coastguard Worker layerName.c_str());
1663*89c4ff92SAndroid Build Coastguard Worker
1664*89c4ff92SAndroid Build Coastguard Worker armnn::Optional<armnn::ConstTensor> optionalBiases = armnn::EmptyOptional();
1665*89c4ff92SAndroid Build Coastguard Worker if (descriptor.m_BiasEnabled)
1666*89c4ff92SAndroid Build Coastguard Worker {
1667*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor biases = ToConstTensor(serializerLayer->biases());
1668*89c4ff92SAndroid Build Coastguard Worker ignoreSlots.emplace_back(2u);
1669*89c4ff92SAndroid Build Coastguard Worker
1670*89c4ff92SAndroid Build Coastguard Worker auto biasLayer = m_Network->AddConstantLayer(biases);
1671*89c4ff92SAndroid Build Coastguard Worker biasLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(2u));
1672*89c4ff92SAndroid Build Coastguard Worker biasLayer->GetOutputSlot(0).SetTensorInfo(biases.GetInfo());
1673*89c4ff92SAndroid Build Coastguard Worker }
1674*89c4ff92SAndroid Build Coastguard Worker
1675*89c4ff92SAndroid Build Coastguard Worker if (this->GetFeatureVersions(graph).m_WeightsLayoutScheme <= 0)
1676*89c4ff92SAndroid Build Coastguard Worker {
1677*89c4ff92SAndroid Build Coastguard Worker // Permute weights [ H, W, M, I ] --> [ 1, H, W, I*M ]
1678*89c4ff92SAndroid Build Coastguard Worker // Step1: [ M, I, H, W ] --> [ H, W, I, M]
1679*89c4ff92SAndroid Build Coastguard Worker PermutationVector permutationVector = { 3, 2, 0, 1 };
1680*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo weightsInfo = weights.GetInfo();
1681*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<unsigned char[]> permuteBuffer(new unsigned char[weightsInfo.GetNumBytes()]);
1682*89c4ff92SAndroid Build Coastguard Worker weightsInfo = armnnUtils::Permuted(weightsInfo, permutationVector);
1683*89c4ff92SAndroid Build Coastguard Worker armnnUtils::Permute(weightsInfo.GetShape(), permutationVector,
1684*89c4ff92SAndroid Build Coastguard Worker weights.GetMemoryArea(), permuteBuffer.get(),
1685*89c4ff92SAndroid Build Coastguard Worker GetDataTypeSize(weightsInfo.GetDataType()));
1686*89c4ff92SAndroid Build Coastguard Worker
1687*89c4ff92SAndroid Build Coastguard Worker // Step2: Reshape [ H, W, I, M] --> [ 1, H, W, I*M ]
1688*89c4ff92SAndroid Build Coastguard Worker auto weightsShape = weightsInfo.GetShape();
1689*89c4ff92SAndroid Build Coastguard Worker weightsInfo.SetShape({1,
1690*89c4ff92SAndroid Build Coastguard Worker weightsShape[0],
1691*89c4ff92SAndroid Build Coastguard Worker weightsShape[1],
1692*89c4ff92SAndroid Build Coastguard Worker weightsShape[2]*weightsShape[3]});
1693*89c4ff92SAndroid Build Coastguard Worker
1694*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor weightsPermuted(weightsInfo, permuteBuffer.get());
1695*89c4ff92SAndroid Build Coastguard Worker
1696*89c4ff92SAndroid Build Coastguard Worker auto weightsLayer = m_Network->AddConstantLayer(weightsPermuted);
1697*89c4ff92SAndroid Build Coastguard Worker weightsLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1u));
1698*89c4ff92SAndroid Build Coastguard Worker weightsLayer->GetOutputSlot(0).SetTensorInfo(weightsPermuted.GetInfo());
1699*89c4ff92SAndroid Build Coastguard Worker }
1700*89c4ff92SAndroid Build Coastguard Worker else
1701*89c4ff92SAndroid Build Coastguard Worker {
1702*89c4ff92SAndroid Build Coastguard Worker auto weightsLayer = m_Network->AddConstantLayer(weights);
1703*89c4ff92SAndroid Build Coastguard Worker weightsLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1u));
1704*89c4ff92SAndroid Build Coastguard Worker weightsLayer->GetOutputSlot(0).SetTensorInfo(weights.GetInfo());
1705*89c4ff92SAndroid Build Coastguard Worker }
1706*89c4ff92SAndroid Build Coastguard Worker }
1707*89c4ff92SAndroid Build Coastguard Worker else
1708*89c4ff92SAndroid Build Coastguard Worker {
1709*89c4ff92SAndroid Build Coastguard Worker layer = m_Network->AddDepthwiseConvolution2dLayer(descriptor,
1710*89c4ff92SAndroid Build Coastguard Worker layerName.c_str());
1711*89c4ff92SAndroid Build Coastguard Worker uint32_t numInputs = descriptor.GetNumInputs();
1712*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), numInputs);
1713*89c4ff92SAndroid Build Coastguard Worker }
1714*89c4ff92SAndroid Build Coastguard Worker
1715*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
1716*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
1717*89c4ff92SAndroid Build Coastguard Worker
1718*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer, ignoreSlots);
1719*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
1720*89c4ff92SAndroid Build Coastguard Worker }
1721*89c4ff92SAndroid Build Coastguard Worker
ParseDetectionPostProcess(GraphPtr graph,unsigned int layerIndex)1722*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseDetectionPostProcess(GraphPtr graph, unsigned int layerIndex)
1723*89c4ff92SAndroid Build Coastguard Worker {
1724*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
1725*89c4ff92SAndroid Build Coastguard Worker auto inputs = GetInputs(graph, layerIndex);
1726*89c4ff92SAndroid Build Coastguard Worker CHECK_LOCATION();
1727*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 2);
1728*89c4ff92SAndroid Build Coastguard Worker
1729*89c4ff92SAndroid Build Coastguard Worker auto outputs = GetOutputs(graph, layerIndex);
1730*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 4);
1731*89c4ff92SAndroid Build Coastguard Worker
1732*89c4ff92SAndroid Build Coastguard Worker auto flatBufferLayer = graph->layers()->Get(layerIndex)->layer_as_DetectionPostProcessLayer();
1733*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
1734*89c4ff92SAndroid Build Coastguard Worker auto flatBufferDescriptor = flatBufferLayer->descriptor();
1735*89c4ff92SAndroid Build Coastguard Worker
1736*89c4ff92SAndroid Build Coastguard Worker armnn::DetectionPostProcessDescriptor descriptor;
1737*89c4ff92SAndroid Build Coastguard Worker descriptor.m_MaxDetections = flatBufferDescriptor->maxDetections();
1738*89c4ff92SAndroid Build Coastguard Worker descriptor.m_MaxClassesPerDetection = flatBufferDescriptor->maxClassesPerDetection();
1739*89c4ff92SAndroid Build Coastguard Worker descriptor.m_DetectionsPerClass = flatBufferDescriptor->detectionsPerClass();
1740*89c4ff92SAndroid Build Coastguard Worker descriptor.m_NmsScoreThreshold = flatBufferDescriptor->nmsScoreThreshold();
1741*89c4ff92SAndroid Build Coastguard Worker descriptor.m_NmsIouThreshold = flatBufferDescriptor->nmsIouThreshold();
1742*89c4ff92SAndroid Build Coastguard Worker descriptor.m_NumClasses = flatBufferDescriptor->numClasses();
1743*89c4ff92SAndroid Build Coastguard Worker descriptor.m_UseRegularNms = flatBufferDescriptor->useRegularNms();
1744*89c4ff92SAndroid Build Coastguard Worker descriptor.m_ScaleX = flatBufferDescriptor->scaleX();
1745*89c4ff92SAndroid Build Coastguard Worker descriptor.m_ScaleY = flatBufferDescriptor->scaleY();
1746*89c4ff92SAndroid Build Coastguard Worker descriptor.m_ScaleW = flatBufferDescriptor->scaleW();
1747*89c4ff92SAndroid Build Coastguard Worker descriptor.m_ScaleH = flatBufferDescriptor->scaleH();
1748*89c4ff92SAndroid Build Coastguard Worker
1749*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor anchors = ToConstTensor(flatBufferLayer->anchors());
1750*89c4ff92SAndroid Build Coastguard Worker
1751*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddDetectionPostProcessLayer(descriptor,
1752*89c4ff92SAndroid Build Coastguard Worker anchors,
1753*89c4ff92SAndroid Build Coastguard Worker layerName.c_str());
1754*89c4ff92SAndroid Build Coastguard Worker
1755*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < 4; i++)
1756*89c4ff92SAndroid Build Coastguard Worker {
1757*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(i).SetTensorInfo(ToTensorInfo(outputs[i]));
1758*89c4ff92SAndroid Build Coastguard Worker }
1759*89c4ff92SAndroid Build Coastguard Worker
1760*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
1761*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
1762*89c4ff92SAndroid Build Coastguard Worker }
1763*89c4ff92SAndroid Build Coastguard Worker
ParseDivision(GraphPtr graph,unsigned int layerIndex)1764*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseDivision(GraphPtr graph, unsigned int layerIndex)
1765*89c4ff92SAndroid Build Coastguard Worker {
1766*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
1767*89c4ff92SAndroid Build Coastguard Worker auto inputs = GetInputs(graph, layerIndex);
1768*89c4ff92SAndroid Build Coastguard Worker CHECK_LOCATION();
1769*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 2);
1770*89c4ff92SAndroid Build Coastguard Worker
1771*89c4ff92SAndroid Build Coastguard Worker auto outputs = GetOutputs(graph, layerIndex);
1772*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
1773*89c4ff92SAndroid Build Coastguard Worker
1774*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
1775*89c4ff92SAndroid Build Coastguard Worker armnn::ElementwiseBinaryDescriptor descriptor(armnn::BinaryOperation::Div);
1776*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddElementwiseBinaryLayer(descriptor, layerName.c_str());
1777*89c4ff92SAndroid Build Coastguard Worker
1778*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
1779*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
1780*89c4ff92SAndroid Build Coastguard Worker
1781*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
1782*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
1783*89c4ff92SAndroid Build Coastguard Worker }
1784*89c4ff92SAndroid Build Coastguard Worker
ParseEqual(GraphPtr graph,unsigned int layerIndex)1785*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseEqual(GraphPtr graph, unsigned int layerIndex)
1786*89c4ff92SAndroid Build Coastguard Worker {
1787*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
1788*89c4ff92SAndroid Build Coastguard Worker auto inputs = GetInputs(graph, layerIndex);
1789*89c4ff92SAndroid Build Coastguard Worker CHECK_LOCATION();
1790*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 2);
1791*89c4ff92SAndroid Build Coastguard Worker
1792*89c4ff92SAndroid Build Coastguard Worker auto outputs = GetOutputs(graph, layerIndex);
1793*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
1794*89c4ff92SAndroid Build Coastguard Worker
1795*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
1796*89c4ff92SAndroid Build Coastguard Worker armnn::ComparisonDescriptor descriptor(armnn::ComparisonOperation::Equal);
1797*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddComparisonLayer(descriptor, layerName.c_str());
1798*89c4ff92SAndroid Build Coastguard Worker
1799*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
1800*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
1801*89c4ff92SAndroid Build Coastguard Worker
1802*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
1803*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
1804*89c4ff92SAndroid Build Coastguard Worker }
1805*89c4ff92SAndroid Build Coastguard Worker
ParseFill(GraphPtr graph,unsigned int layerIndex)1806*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseFill(GraphPtr graph, unsigned int layerIndex)
1807*89c4ff92SAndroid Build Coastguard Worker {
1808*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
1809*89c4ff92SAndroid Build Coastguard Worker auto inputs = GetInputs(graph, layerIndex);
1810*89c4ff92SAndroid Build Coastguard Worker CHECK_LOCATION();
1811*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 1);
1812*89c4ff92SAndroid Build Coastguard Worker
1813*89c4ff92SAndroid Build Coastguard Worker auto outputs = GetOutputs(graph, layerIndex);
1814*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
1815*89c4ff92SAndroid Build Coastguard Worker
1816*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
1817*89c4ff92SAndroid Build Coastguard Worker armnn::FillDescriptor descriptor;
1818*89c4ff92SAndroid Build Coastguard Worker descriptor.m_Value = graph->layers()->Get(layerIndex)->layer_as_FillLayer()->descriptor()->value();
1819*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddFillLayer(descriptor, layerName.c_str());
1820*89c4ff92SAndroid Build Coastguard Worker
1821*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
1822*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
1823*89c4ff92SAndroid Build Coastguard Worker
1824*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
1825*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
1826*89c4ff92SAndroid Build Coastguard Worker }
1827*89c4ff92SAndroid Build Coastguard Worker
ParseGreater(GraphPtr graph,unsigned int layerIndex)1828*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseGreater(GraphPtr graph, unsigned int layerIndex)
1829*89c4ff92SAndroid Build Coastguard Worker {
1830*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
1831*89c4ff92SAndroid Build Coastguard Worker auto inputs = GetInputs(graph, layerIndex);
1832*89c4ff92SAndroid Build Coastguard Worker CHECK_LOCATION();
1833*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 2);
1834*89c4ff92SAndroid Build Coastguard Worker
1835*89c4ff92SAndroid Build Coastguard Worker auto outputs = GetOutputs(graph, layerIndex);
1836*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
1837*89c4ff92SAndroid Build Coastguard Worker
1838*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
1839*89c4ff92SAndroid Build Coastguard Worker armnn::ComparisonDescriptor descriptor(armnn::ComparisonOperation::Greater);
1840*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddComparisonLayer(descriptor, layerName.c_str());
1841*89c4ff92SAndroid Build Coastguard Worker
1842*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
1843*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
1844*89c4ff92SAndroid Build Coastguard Worker
1845*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
1846*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
1847*89c4ff92SAndroid Build Coastguard Worker }
1848*89c4ff92SAndroid Build Coastguard Worker
ParseInstanceNormalization(GraphPtr graph,unsigned int layerIndex)1849*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseInstanceNormalization(GraphPtr graph, unsigned int layerIndex)
1850*89c4ff92SAndroid Build Coastguard Worker {
1851*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
1852*89c4ff92SAndroid Build Coastguard Worker
1853*89c4ff92SAndroid Build Coastguard Worker auto inputs = GetInputs(graph, layerIndex);
1854*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 1);
1855*89c4ff92SAndroid Build Coastguard Worker
1856*89c4ff92SAndroid Build Coastguard Worker auto outputs = GetOutputs(graph, layerIndex);
1857*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
1858*89c4ff92SAndroid Build Coastguard Worker
1859*89c4ff92SAndroid Build Coastguard Worker auto fbLayer = graph->layers()->Get(layerIndex)->layer_as_InstanceNormalizationLayer();
1860*89c4ff92SAndroid Build Coastguard Worker auto fbDescriptor = fbLayer->descriptor();
1861*89c4ff92SAndroid Build Coastguard Worker
1862*89c4ff92SAndroid Build Coastguard Worker armnn::InstanceNormalizationDescriptor descriptor;
1863*89c4ff92SAndroid Build Coastguard Worker descriptor.m_Gamma = fbDescriptor->gamma();
1864*89c4ff92SAndroid Build Coastguard Worker descriptor.m_Beta = fbDescriptor->beta();
1865*89c4ff92SAndroid Build Coastguard Worker descriptor.m_Eps = fbDescriptor->eps();
1866*89c4ff92SAndroid Build Coastguard Worker descriptor.m_DataLayout = ToDataLayout(fbDescriptor->dataLayout());
1867*89c4ff92SAndroid Build Coastguard Worker
1868*89c4ff92SAndroid Build Coastguard Worker const std::string layerName = GetLayerName(graph, layerIndex);
1869*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo outputInfo = ToTensorInfo(outputs[0]);
1870*89c4ff92SAndroid Build Coastguard Worker
1871*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddInstanceNormalizationLayer(descriptor, layerName.c_str());
1872*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputInfo);
1873*89c4ff92SAndroid Build Coastguard Worker
1874*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
1875*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
1876*89c4ff92SAndroid Build Coastguard Worker }
1877*89c4ff92SAndroid Build Coastguard Worker
ParseL2Normalization(GraphPtr graph,unsigned int layerIndex)1878*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseL2Normalization(GraphPtr graph, unsigned int layerIndex)
1879*89c4ff92SAndroid Build Coastguard Worker {
1880*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
1881*89c4ff92SAndroid Build Coastguard Worker
1882*89c4ff92SAndroid Build Coastguard Worker auto inputs = GetInputs(graph, layerIndex);
1883*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 1);
1884*89c4ff92SAndroid Build Coastguard Worker
1885*89c4ff92SAndroid Build Coastguard Worker auto outputs = GetOutputs(graph, layerIndex);
1886*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
1887*89c4ff92SAndroid Build Coastguard Worker auto outputInfo = ToTensorInfo(outputs[0]);
1888*89c4ff92SAndroid Build Coastguard Worker
1889*89c4ff92SAndroid Build Coastguard Worker auto flatBufferLayer = graph->layers()->Get(layerIndex)->layer_as_L2NormalizationLayer();
1890*89c4ff92SAndroid Build Coastguard Worker auto flatBufferDescriptor = flatBufferLayer->descriptor();
1891*89c4ff92SAndroid Build Coastguard Worker
1892*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
1893*89c4ff92SAndroid Build Coastguard Worker armnn::L2NormalizationDescriptor descriptor;
1894*89c4ff92SAndroid Build Coastguard Worker descriptor.m_DataLayout = ToDataLayout(flatBufferDescriptor->dataLayout());
1895*89c4ff92SAndroid Build Coastguard Worker descriptor.m_Eps = flatBufferDescriptor->eps();
1896*89c4ff92SAndroid Build Coastguard Worker
1897*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddL2NormalizationLayer(descriptor, layerName.c_str());
1898*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputInfo);
1899*89c4ff92SAndroid Build Coastguard Worker
1900*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
1901*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
1902*89c4ff92SAndroid Build Coastguard Worker }
1903*89c4ff92SAndroid Build Coastguard Worker
ParseLogicalBinary(GraphPtr graph,unsigned int layerIndex)1904*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseLogicalBinary(GraphPtr graph, unsigned int layerIndex)
1905*89c4ff92SAndroid Build Coastguard Worker {
1906*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
1907*89c4ff92SAndroid Build Coastguard Worker CHECK_LOCATION();
1908*89c4ff92SAndroid Build Coastguard Worker
1909*89c4ff92SAndroid Build Coastguard Worker auto inputs = GetInputs(graph, layerIndex);
1910*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 2);
1911*89c4ff92SAndroid Build Coastguard Worker
1912*89c4ff92SAndroid Build Coastguard Worker auto outputs = GetOutputs(graph, layerIndex);
1913*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
1914*89c4ff92SAndroid Build Coastguard Worker
1915*89c4ff92SAndroid Build Coastguard Worker auto fbLayer = graph->layers()->Get(layerIndex)->layer_as_LogicalBinaryLayer();
1916*89c4ff92SAndroid Build Coastguard Worker auto fbDescriptor = fbLayer->descriptor();
1917*89c4ff92SAndroid Build Coastguard Worker
1918*89c4ff92SAndroid Build Coastguard Worker armnn::LogicalBinaryDescriptor descriptor;
1919*89c4ff92SAndroid Build Coastguard Worker descriptor.m_Operation = ToLogicalBinaryOperation(fbDescriptor->operation());
1920*89c4ff92SAndroid Build Coastguard Worker
1921*89c4ff92SAndroid Build Coastguard Worker const std::string& layerName = GetLayerName(graph, layerIndex);
1922*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddLogicalBinaryLayer(descriptor, layerName.c_str());
1923*89c4ff92SAndroid Build Coastguard Worker
1924*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
1925*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
1926*89c4ff92SAndroid Build Coastguard Worker
1927*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
1928*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
1929*89c4ff92SAndroid Build Coastguard Worker }
1930*89c4ff92SAndroid Build Coastguard Worker
ParseLogSoftmax(GraphPtr graph,unsigned int layerIndex)1931*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseLogSoftmax(GraphPtr graph, unsigned int layerIndex)
1932*89c4ff92SAndroid Build Coastguard Worker {
1933*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
1934*89c4ff92SAndroid Build Coastguard Worker
1935*89c4ff92SAndroid Build Coastguard Worker TensorRawPtrVector inputs = GetInputs(graph, layerIndex);
1936*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 1);
1937*89c4ff92SAndroid Build Coastguard Worker
1938*89c4ff92SAndroid Build Coastguard Worker TensorRawPtrVector outputs = GetOutputs(graph, layerIndex);
1939*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
1940*89c4ff92SAndroid Build Coastguard Worker
1941*89c4ff92SAndroid Build Coastguard Worker armnn::LogSoftmaxDescriptor descriptor;
1942*89c4ff92SAndroid Build Coastguard Worker descriptor.m_Beta = graph->layers()->Get(layerIndex)->layer_as_LogSoftmaxLayer()->descriptor()->beta();
1943*89c4ff92SAndroid Build Coastguard Worker descriptor.m_Axis = graph->layers()->Get(layerIndex)->layer_as_LogSoftmaxLayer()->descriptor()->axis();
1944*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
1945*89c4ff92SAndroid Build Coastguard Worker
1946*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddLogSoftmaxLayer(descriptor, layerName.c_str());
1947*89c4ff92SAndroid Build Coastguard Worker
1948*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
1949*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
1950*89c4ff92SAndroid Build Coastguard Worker
1951*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
1952*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
1953*89c4ff92SAndroid Build Coastguard Worker }
1954*89c4ff92SAndroid Build Coastguard Worker
ParseMinimum(GraphPtr graph,unsigned int layerIndex)1955*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseMinimum(GraphPtr graph, unsigned int layerIndex)
1956*89c4ff92SAndroid Build Coastguard Worker {
1957*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
1958*89c4ff92SAndroid Build Coastguard Worker auto inputs = GetInputs(graph, layerIndex);
1959*89c4ff92SAndroid Build Coastguard Worker CHECK_LOCATION();
1960*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 2);
1961*89c4ff92SAndroid Build Coastguard Worker
1962*89c4ff92SAndroid Build Coastguard Worker auto outputs = GetOutputs(graph, layerIndex);
1963*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
1964*89c4ff92SAndroid Build Coastguard Worker
1965*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
1966*89c4ff92SAndroid Build Coastguard Worker armnn::ElementwiseBinaryDescriptor descriptor(armnn::BinaryOperation::Minimum);
1967*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddElementwiseBinaryLayer(descriptor, layerName.c_str());
1968*89c4ff92SAndroid Build Coastguard Worker
1969*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
1970*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
1971*89c4ff92SAndroid Build Coastguard Worker
1972*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
1973*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
1974*89c4ff92SAndroid Build Coastguard Worker }
1975*89c4ff92SAndroid Build Coastguard Worker
ParseMaximum(GraphPtr graph,unsigned int layerIndex)1976*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseMaximum(GraphPtr graph, unsigned int layerIndex)
1977*89c4ff92SAndroid Build Coastguard Worker {
1978*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
1979*89c4ff92SAndroid Build Coastguard Worker auto inputs = GetInputs(graph, layerIndex);
1980*89c4ff92SAndroid Build Coastguard Worker CHECK_LOCATION();
1981*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 2);
1982*89c4ff92SAndroid Build Coastguard Worker
1983*89c4ff92SAndroid Build Coastguard Worker auto outputs = GetOutputs(graph, layerIndex);
1984*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
1985*89c4ff92SAndroid Build Coastguard Worker
1986*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
1987*89c4ff92SAndroid Build Coastguard Worker armnn::ElementwiseBinaryDescriptor descriptor(armnn::BinaryOperation::Maximum);
1988*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddElementwiseBinaryLayer(descriptor, layerName.c_str());
1989*89c4ff92SAndroid Build Coastguard Worker
1990*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
1991*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
1992*89c4ff92SAndroid Build Coastguard Worker
1993*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
1994*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
1995*89c4ff92SAndroid Build Coastguard Worker }
1996*89c4ff92SAndroid Build Coastguard Worker
GetOriginsDescriptor(const armnnSerializer::SerializedGraph * graph,unsigned int layerIndex)1997*89c4ff92SAndroid Build Coastguard Worker const armnnSerializer::OriginsDescriptor* GetOriginsDescriptor(const armnnSerializer::SerializedGraph* graph,
1998*89c4ff92SAndroid Build Coastguard Worker unsigned int layerIndex)
1999*89c4ff92SAndroid Build Coastguard Worker {
2000*89c4ff92SAndroid Build Coastguard Worker auto layerType = graph->layers()->Get(layerIndex)->layer_type();
2001*89c4ff92SAndroid Build Coastguard Worker
2002*89c4ff92SAndroid Build Coastguard Worker switch (layerType)
2003*89c4ff92SAndroid Build Coastguard Worker {
2004*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_ConcatLayer:
2005*89c4ff92SAndroid Build Coastguard Worker return graph->layers()->Get(layerIndex)->layer_as_ConcatLayer()->descriptor();
2006*89c4ff92SAndroid Build Coastguard Worker case Layer::Layer_MergerLayer:
2007*89c4ff92SAndroid Build Coastguard Worker return graph->layers()->Get(layerIndex)->layer_as_MergerLayer()->descriptor();
2008*89c4ff92SAndroid Build Coastguard Worker default:
2009*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception("unknown layer type, should be concat or merger");
2010*89c4ff92SAndroid Build Coastguard Worker }
2011*89c4ff92SAndroid Build Coastguard Worker }
ParseChannelShuffle(GraphPtr graph,unsigned int layerIndex)2012*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseChannelShuffle(GraphPtr graph, unsigned int layerIndex)
2013*89c4ff92SAndroid Build Coastguard Worker {
2014*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
2015*89c4ff92SAndroid Build Coastguard Worker
2016*89c4ff92SAndroid Build Coastguard Worker TensorRawPtrVector inputs = GetInputs(graph, layerIndex);
2017*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 1);
2018*89c4ff92SAndroid Build Coastguard Worker
2019*89c4ff92SAndroid Build Coastguard Worker TensorRawPtrVector outputs = GetOutputs(graph, layerIndex);
2020*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
2021*89c4ff92SAndroid Build Coastguard Worker
2022*89c4ff92SAndroid Build Coastguard Worker armnn::ChannelShuffleDescriptor descriptor;
2023*89c4ff92SAndroid Build Coastguard Worker descriptor.m_Axis = graph->layers()->Get(layerIndex)->layer_as_ChannelShuffleLayer()->descriptor()->axis();
2024*89c4ff92SAndroid Build Coastguard Worker descriptor.m_NumGroups =
2025*89c4ff92SAndroid Build Coastguard Worker graph->layers()->Get(layerIndex)->layer_as_ChannelShuffleLayer()->descriptor()->numGroups();
2026*89c4ff92SAndroid Build Coastguard Worker
2027*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
2028*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddChannelShuffleLayer(descriptor, layerName.c_str());
2029*89c4ff92SAndroid Build Coastguard Worker
2030*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
2031*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
2032*89c4ff92SAndroid Build Coastguard Worker
2033*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
2034*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
2035*89c4ff92SAndroid Build Coastguard Worker }
ParseComparison(GraphPtr graph,unsigned int layerIndex)2036*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseComparison(GraphPtr graph, unsigned int layerIndex)
2037*89c4ff92SAndroid Build Coastguard Worker {
2038*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
2039*89c4ff92SAndroid Build Coastguard Worker CHECK_LOCATION();
2040*89c4ff92SAndroid Build Coastguard Worker
2041*89c4ff92SAndroid Build Coastguard Worker auto inputs = GetInputs(graph, layerIndex);
2042*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 2);
2043*89c4ff92SAndroid Build Coastguard Worker
2044*89c4ff92SAndroid Build Coastguard Worker auto outputs = GetOutputs(graph, layerIndex);
2045*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
2046*89c4ff92SAndroid Build Coastguard Worker
2047*89c4ff92SAndroid Build Coastguard Worker auto fbLayer = graph->layers()->Get(layerIndex)->layer_as_ComparisonLayer();
2048*89c4ff92SAndroid Build Coastguard Worker auto fbDescriptor = fbLayer->descriptor();
2049*89c4ff92SAndroid Build Coastguard Worker
2050*89c4ff92SAndroid Build Coastguard Worker armnn::ComparisonDescriptor descriptor;
2051*89c4ff92SAndroid Build Coastguard Worker descriptor.m_Operation = ToComparisonOperation(fbDescriptor->operation());
2052*89c4ff92SAndroid Build Coastguard Worker
2053*89c4ff92SAndroid Build Coastguard Worker const std::string& layerName = GetLayerName(graph, layerIndex);
2054*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddComparisonLayer(descriptor, layerName.c_str());
2055*89c4ff92SAndroid Build Coastguard Worker
2056*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
2057*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
2058*89c4ff92SAndroid Build Coastguard Worker
2059*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
2060*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
2061*89c4ff92SAndroid Build Coastguard Worker }
2062*89c4ff92SAndroid Build Coastguard Worker
ParseElementwiseBinary(GraphPtr graph,unsigned int layerIndex)2063*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseElementwiseBinary(GraphPtr graph, unsigned int layerIndex)
2064*89c4ff92SAndroid Build Coastguard Worker {
2065*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
2066*89c4ff92SAndroid Build Coastguard Worker CHECK_LOCATION();
2067*89c4ff92SAndroid Build Coastguard Worker
2068*89c4ff92SAndroid Build Coastguard Worker auto inputs = GetInputs(graph, layerIndex);
2069*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 2);
2070*89c4ff92SAndroid Build Coastguard Worker
2071*89c4ff92SAndroid Build Coastguard Worker auto outputs = GetOutputs(graph, layerIndex);
2072*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
2073*89c4ff92SAndroid Build Coastguard Worker
2074*89c4ff92SAndroid Build Coastguard Worker auto fbLayer = graph->layers()->Get(layerIndex)->layer_as_ElementwiseBinaryLayer();
2075*89c4ff92SAndroid Build Coastguard Worker auto fbDescriptor = fbLayer->descriptor();
2076*89c4ff92SAndroid Build Coastguard Worker
2077*89c4ff92SAndroid Build Coastguard Worker armnn::ElementwiseBinaryDescriptor descriptor;
2078*89c4ff92SAndroid Build Coastguard Worker descriptor.m_Operation = ToElementwiseBinaryOperation(fbDescriptor->operation());
2079*89c4ff92SAndroid Build Coastguard Worker
2080*89c4ff92SAndroid Build Coastguard Worker const std::string& layerName = GetLayerName(graph, layerIndex);
2081*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddElementwiseBinaryLayer(descriptor, layerName.c_str());
2082*89c4ff92SAndroid Build Coastguard Worker
2083*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
2084*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
2085*89c4ff92SAndroid Build Coastguard Worker
2086*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
2087*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
2088*89c4ff92SAndroid Build Coastguard Worker }
2089*89c4ff92SAndroid Build Coastguard Worker
ParseElementwiseUnary(GraphPtr graph,unsigned int layerIndex)2090*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseElementwiseUnary(GraphPtr graph, unsigned int layerIndex)
2091*89c4ff92SAndroid Build Coastguard Worker {
2092*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
2093*89c4ff92SAndroid Build Coastguard Worker CHECK_LOCATION();
2094*89c4ff92SAndroid Build Coastguard Worker
2095*89c4ff92SAndroid Build Coastguard Worker auto inputs = GetInputs(graph, layerIndex);
2096*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 1);
2097*89c4ff92SAndroid Build Coastguard Worker
2098*89c4ff92SAndroid Build Coastguard Worker auto outputs = GetOutputs(graph, layerIndex);
2099*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
2100*89c4ff92SAndroid Build Coastguard Worker
2101*89c4ff92SAndroid Build Coastguard Worker auto fbLayer = graph->layers()->Get(layerIndex)->layer_as_ElementwiseUnaryLayer();
2102*89c4ff92SAndroid Build Coastguard Worker auto fbDescriptor = fbLayer->descriptor();
2103*89c4ff92SAndroid Build Coastguard Worker
2104*89c4ff92SAndroid Build Coastguard Worker armnn::ElementwiseUnaryDescriptor descriptor;
2105*89c4ff92SAndroid Build Coastguard Worker descriptor.m_Operation = ToElementwiseUnaryOperation(fbDescriptor->operation());
2106*89c4ff92SAndroid Build Coastguard Worker
2107*89c4ff92SAndroid Build Coastguard Worker const std::string& layerName = GetLayerName(graph, layerIndex);
2108*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddElementwiseUnaryLayer(descriptor, layerName.c_str());
2109*89c4ff92SAndroid Build Coastguard Worker
2110*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
2111*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
2112*89c4ff92SAndroid Build Coastguard Worker
2113*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
2114*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
2115*89c4ff92SAndroid Build Coastguard Worker }
2116*89c4ff92SAndroid Build Coastguard Worker
ParseConcat(GraphPtr graph,unsigned int layerIndex)2117*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseConcat(GraphPtr graph, unsigned int layerIndex)
2118*89c4ff92SAndroid Build Coastguard Worker {
2119*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
2120*89c4ff92SAndroid Build Coastguard Worker CHECK_LOCATION();
2121*89c4ff92SAndroid Build Coastguard Worker
2122*89c4ff92SAndroid Build Coastguard Worker auto outputs = GetOutputs(graph, layerIndex);
2123*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
2124*89c4ff92SAndroid Build Coastguard Worker
2125*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
2126*89c4ff92SAndroid Build Coastguard Worker auto originsDescriptor = GetOriginsDescriptor(graph, layerIndex);
2127*89c4ff92SAndroid Build Coastguard Worker unsigned int numViews = originsDescriptor->numViews();
2128*89c4ff92SAndroid Build Coastguard Worker unsigned int numDimensions = originsDescriptor->numDimensions();
2129*89c4ff92SAndroid Build Coastguard Worker
2130*89c4ff92SAndroid Build Coastguard Worker // can now check the number of inputs == number of views
2131*89c4ff92SAndroid Build Coastguard Worker auto inputs = GetInputs(graph, layerIndex);
2132*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), numViews);
2133*89c4ff92SAndroid Build Coastguard Worker
2134*89c4ff92SAndroid Build Coastguard Worker armnn::OriginsDescriptor descriptor(numViews, numDimensions);
2135*89c4ff92SAndroid Build Coastguard Worker auto originsPtr = originsDescriptor->viewOrigins();
2136*89c4ff92SAndroid Build Coastguard Worker for (unsigned int v = 0; v < numViews; ++v)
2137*89c4ff92SAndroid Build Coastguard Worker {
2138*89c4ff92SAndroid Build Coastguard Worker auto originPtr = originsPtr->Get(v);
2139*89c4ff92SAndroid Build Coastguard Worker for (unsigned int d = 0; d < numDimensions; ++d)
2140*89c4ff92SAndroid Build Coastguard Worker {
2141*89c4ff92SAndroid Build Coastguard Worker uint32_t value = originPtr->data()->Get(d);
2142*89c4ff92SAndroid Build Coastguard Worker descriptor.SetViewOriginCoord(v, d, value);
2143*89c4ff92SAndroid Build Coastguard Worker }
2144*89c4ff92SAndroid Build Coastguard Worker }
2145*89c4ff92SAndroid Build Coastguard Worker descriptor.SetConcatAxis(originsDescriptor->concatAxis());
2146*89c4ff92SAndroid Build Coastguard Worker
2147*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddConcatLayer(descriptor, layerName.c_str());
2148*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
2149*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
2150*89c4ff92SAndroid Build Coastguard Worker
2151*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
2152*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
2153*89c4ff92SAndroid Build Coastguard Worker }
2154*89c4ff92SAndroid Build Coastguard Worker
ParseMultiplication(GraphPtr graph,unsigned int layerIndex)2155*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseMultiplication(GraphPtr graph, unsigned int layerIndex)
2156*89c4ff92SAndroid Build Coastguard Worker {
2157*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
2158*89c4ff92SAndroid Build Coastguard Worker auto inputs = GetInputs(graph, layerIndex);
2159*89c4ff92SAndroid Build Coastguard Worker CHECK_LOCATION();
2160*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 2);
2161*89c4ff92SAndroid Build Coastguard Worker
2162*89c4ff92SAndroid Build Coastguard Worker auto outputs = GetOutputs(graph, layerIndex);
2163*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
2164*89c4ff92SAndroid Build Coastguard Worker
2165*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
2166*89c4ff92SAndroid Build Coastguard Worker armnn::ElementwiseBinaryDescriptor descriptor(armnn::BinaryOperation::Mul);
2167*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddElementwiseBinaryLayer(descriptor, layerName.c_str());
2168*89c4ff92SAndroid Build Coastguard Worker
2169*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
2170*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
2171*89c4ff92SAndroid Build Coastguard Worker
2172*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
2173*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
2174*89c4ff92SAndroid Build Coastguard Worker }
2175*89c4ff92SAndroid Build Coastguard Worker
ParseFloor(GraphPtr graph,unsigned int layerIndex)2176*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseFloor(GraphPtr graph, unsigned int layerIndex)
2177*89c4ff92SAndroid Build Coastguard Worker {
2178*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
2179*89c4ff92SAndroid Build Coastguard Worker CHECK_LOCATION();
2180*89c4ff92SAndroid Build Coastguard Worker
2181*89c4ff92SAndroid Build Coastguard Worker auto inputs = GetInputs(graph, layerIndex);
2182*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 1);
2183*89c4ff92SAndroid Build Coastguard Worker
2184*89c4ff92SAndroid Build Coastguard Worker auto outputs = GetOutputs(graph, layerIndex);
2185*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
2186*89c4ff92SAndroid Build Coastguard Worker
2187*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
2188*89c4ff92SAndroid Build Coastguard Worker
2189*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* layer;
2190*89c4ff92SAndroid Build Coastguard Worker
2191*89c4ff92SAndroid Build Coastguard Worker layer = m_Network->AddFloorLayer(layerName.c_str());
2192*89c4ff92SAndroid Build Coastguard Worker
2193*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
2194*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
2195*89c4ff92SAndroid Build Coastguard Worker
2196*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
2197*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
2198*89c4ff92SAndroid Build Coastguard Worker }
2199*89c4ff92SAndroid Build Coastguard Worker
ParseFullyConnected(GraphPtr graph,unsigned int layerIndex)2200*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseFullyConnected(GraphPtr graph, unsigned int layerIndex)
2201*89c4ff92SAndroid Build Coastguard Worker {
2202*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
2203*89c4ff92SAndroid Build Coastguard Worker auto inputs = GetInputs(graph, layerIndex);
2204*89c4ff92SAndroid Build Coastguard Worker CHECK_LOCATION();
2205*89c4ff92SAndroid Build Coastguard Worker
2206*89c4ff92SAndroid Build Coastguard Worker auto outputs = GetOutputs(graph, layerIndex);
2207*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
2208*89c4ff92SAndroid Build Coastguard Worker
2209*89c4ff92SAndroid Build Coastguard Worker auto flatBufferLayer = graph->layers()->Get(layerIndex)->layer_as_FullyConnectedLayer();
2210*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
2211*89c4ff92SAndroid Build Coastguard Worker auto flatBufferDescriptor = flatBufferLayer->descriptor();
2212*89c4ff92SAndroid Build Coastguard Worker
2213*89c4ff92SAndroid Build Coastguard Worker armnn::FullyConnectedDescriptor fullyConnectedDescriptor;
2214*89c4ff92SAndroid Build Coastguard Worker fullyConnectedDescriptor.m_BiasEnabled = flatBufferDescriptor->biasEnabled();
2215*89c4ff92SAndroid Build Coastguard Worker fullyConnectedDescriptor.m_TransposeWeightMatrix = flatBufferDescriptor->transposeWeightsMatrix();
2216*89c4ff92SAndroid Build Coastguard Worker fullyConnectedDescriptor.m_ConstantWeights = flatBufferDescriptor->constantWeights();
2217*89c4ff92SAndroid Build Coastguard Worker
2218*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* layer;
2219*89c4ff92SAndroid Build Coastguard Worker std::vector<unsigned int> ignoreSlots {};
2220*89c4ff92SAndroid Build Coastguard Worker
2221*89c4ff92SAndroid Build Coastguard Worker // Weights and biases used to be always constant and were stored as members of the layer. This has changed and
2222*89c4ff92SAndroid Build Coastguard Worker // they are now passed as inputs. If they are constant then they will be stored in a ConstantLayer.
2223*89c4ff92SAndroid Build Coastguard Worker if (this->GetFeatureVersions(graph).m_ConstTensorsAsInputs <= 0)
2224*89c4ff92SAndroid Build Coastguard Worker {
2225*89c4ff92SAndroid Build Coastguard Worker // If the model stores weights and biases as members of the layer we have to read them from there
2226*89c4ff92SAndroid Build Coastguard Worker // but add them to their own ConstantLayer for compatibility
2227*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 1);
2228*89c4ff92SAndroid Build Coastguard Worker layer = m_Network->AddFullyConnectedLayer(fullyConnectedDescriptor,
2229*89c4ff92SAndroid Build Coastguard Worker layerName.c_str());
2230*89c4ff92SAndroid Build Coastguard Worker
2231*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor weightsTensor = ToConstTensor(flatBufferLayer->weights());
2232*89c4ff92SAndroid Build Coastguard Worker auto weightsLayer = m_Network->AddConstantLayer(weightsTensor);
2233*89c4ff92SAndroid Build Coastguard Worker weightsLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1u));
2234*89c4ff92SAndroid Build Coastguard Worker weightsLayer->GetOutputSlot(0).SetTensorInfo(weightsTensor.GetInfo());
2235*89c4ff92SAndroid Build Coastguard Worker ignoreSlots.emplace_back(1u);
2236*89c4ff92SAndroid Build Coastguard Worker
2237*89c4ff92SAndroid Build Coastguard Worker if (fullyConnectedDescriptor.m_BiasEnabled)
2238*89c4ff92SAndroid Build Coastguard Worker {
2239*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor biasTensor = ToConstTensor(flatBufferLayer->biases());
2240*89c4ff92SAndroid Build Coastguard Worker auto biasLayer = m_Network->AddConstantLayer(biasTensor);
2241*89c4ff92SAndroid Build Coastguard Worker biasLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(2u));
2242*89c4ff92SAndroid Build Coastguard Worker biasLayer->GetOutputSlot(0).SetTensorInfo(biasTensor.GetInfo());
2243*89c4ff92SAndroid Build Coastguard Worker ignoreSlots.emplace_back(2u);
2244*89c4ff92SAndroid Build Coastguard Worker }
2245*89c4ff92SAndroid Build Coastguard Worker }
2246*89c4ff92SAndroid Build Coastguard Worker else
2247*89c4ff92SAndroid Build Coastguard Worker {
2248*89c4ff92SAndroid Build Coastguard Worker layer = m_Network->AddFullyConnectedLayer(fullyConnectedDescriptor,
2249*89c4ff92SAndroid Build Coastguard Worker layerName.c_str());
2250*89c4ff92SAndroid Build Coastguard Worker uint32_t numInputs = fullyConnectedDescriptor.GetNumInputs();
2251*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), numInputs);
2252*89c4ff92SAndroid Build Coastguard Worker }
2253*89c4ff92SAndroid Build Coastguard Worker
2254*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
2255*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
2256*89c4ff92SAndroid Build Coastguard Worker
2257*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer, ignoreSlots);
2258*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
2259*89c4ff92SAndroid Build Coastguard Worker }
2260*89c4ff92SAndroid Build Coastguard Worker
ParsePad(GraphPtr graph,unsigned int layerIndex)2261*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParsePad(GraphPtr graph, unsigned int layerIndex)
2262*89c4ff92SAndroid Build Coastguard Worker {
2263*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
2264*89c4ff92SAndroid Build Coastguard Worker
2265*89c4ff92SAndroid Build Coastguard Worker TensorRawPtrVector inputs = GetInputs(graph, layerIndex);
2266*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 1);
2267*89c4ff92SAndroid Build Coastguard Worker
2268*89c4ff92SAndroid Build Coastguard Worker TensorRawPtrVector outputs = GetOutputs(graph, layerIndex);
2269*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
2270*89c4ff92SAndroid Build Coastguard Worker
2271*89c4ff92SAndroid Build Coastguard Worker auto flatBufferDescriptor = graph->layers()->Get(layerIndex)->layer_as_PadLayer()->descriptor();
2272*89c4ff92SAndroid Build Coastguard Worker auto flatBufferPadList = flatBufferDescriptor->padList();
2273*89c4ff92SAndroid Build Coastguard Worker auto paddingMode = flatBufferDescriptor->paddingMode();
2274*89c4ff92SAndroid Build Coastguard Worker float padValue = flatBufferDescriptor->padValue();
2275*89c4ff92SAndroid Build Coastguard Worker
2276*89c4ff92SAndroid Build Coastguard Worker if (flatBufferPadList->size() % 2 != 0)
2277*89c4ff92SAndroid Build Coastguard Worker {
2278*89c4ff92SAndroid Build Coastguard Worker throw ParseException(fmt::format("The size of the pad list must be divisible by 2 {}",
2279*89c4ff92SAndroid Build Coastguard Worker CHECK_LOCATION().AsString()));
2280*89c4ff92SAndroid Build Coastguard Worker }
2281*89c4ff92SAndroid Build Coastguard Worker
2282*89c4ff92SAndroid Build Coastguard Worker std::vector<std::pair<unsigned int, unsigned int>> padList;
2283*89c4ff92SAndroid Build Coastguard Worker padList.reserve(flatBufferPadList->size() / 2);
2284*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < flatBufferPadList->size() - 1; i += 2)
2285*89c4ff92SAndroid Build Coastguard Worker {
2286*89c4ff92SAndroid Build Coastguard Worker padList.emplace_back(flatBufferPadList->Get(i), flatBufferPadList->Get(i+1));
2287*89c4ff92SAndroid Build Coastguard Worker }
2288*89c4ff92SAndroid Build Coastguard Worker
2289*89c4ff92SAndroid Build Coastguard Worker armnn::PadDescriptor descriptor(padList, padValue, ToPaddingMode(paddingMode));
2290*89c4ff92SAndroid Build Coastguard Worker
2291*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
2292*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddPadLayer(descriptor, layerName.c_str());
2293*89c4ff92SAndroid Build Coastguard Worker
2294*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
2295*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
2296*89c4ff92SAndroid Build Coastguard Worker
2297*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
2298*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
2299*89c4ff92SAndroid Build Coastguard Worker }
2300*89c4ff92SAndroid Build Coastguard Worker
ParsePermute(GraphPtr graph,unsigned int layerIndex)2301*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParsePermute(GraphPtr graph, unsigned int layerIndex)
2302*89c4ff92SAndroid Build Coastguard Worker {
2303*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
2304*89c4ff92SAndroid Build Coastguard Worker
2305*89c4ff92SAndroid Build Coastguard Worker auto dimsMapping =
2306*89c4ff92SAndroid Build Coastguard Worker graph->layers()->Get(layerIndex)->layer_as_PermuteLayer()->descriptor()->dimMappings();
2307*89c4ff92SAndroid Build Coastguard Worker
2308*89c4ff92SAndroid Build Coastguard Worker auto inputs = GetInputs(graph, layerIndex);
2309*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 1);
2310*89c4ff92SAndroid Build Coastguard Worker
2311*89c4ff92SAndroid Build Coastguard Worker auto outputs = GetOutputs(graph, layerIndex);
2312*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
2313*89c4ff92SAndroid Build Coastguard Worker auto outputInfo = ToTensorInfo(outputs[0]);
2314*89c4ff92SAndroid Build Coastguard Worker
2315*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
2316*89c4ff92SAndroid Build Coastguard Worker const armnn::PermuteDescriptor descriptor(armnn::PermutationVector(dimsMapping->data(), dimsMapping->size()));
2317*89c4ff92SAndroid Build Coastguard Worker
2318*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddPermuteLayer(descriptor, layerName.c_str());
2319*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputInfo);
2320*89c4ff92SAndroid Build Coastguard Worker
2321*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
2322*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
2323*89c4ff92SAndroid Build Coastguard Worker }
2324*89c4ff92SAndroid Build Coastguard Worker
GetPooling2dDescriptor(Pooling2dDescriptor pooling2dDesc,unsigned int layerIndex)2325*89c4ff92SAndroid Build Coastguard Worker armnn::Pooling2dDescriptor IDeserializer::DeserializerImpl::GetPooling2dDescriptor(Pooling2dDescriptor pooling2dDesc,
2326*89c4ff92SAndroid Build Coastguard Worker unsigned int layerIndex)
2327*89c4ff92SAndroid Build Coastguard Worker {
2328*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(layerIndex);
2329*89c4ff92SAndroid Build Coastguard Worker armnn::Pooling2dDescriptor desc;
2330*89c4ff92SAndroid Build Coastguard Worker
2331*89c4ff92SAndroid Build Coastguard Worker switch (pooling2dDesc->poolType())
2332*89c4ff92SAndroid Build Coastguard Worker {
2333*89c4ff92SAndroid Build Coastguard Worker case PoolingAlgorithm_Average:
2334*89c4ff92SAndroid Build Coastguard Worker {
2335*89c4ff92SAndroid Build Coastguard Worker desc.m_PoolType = armnn::PoolingAlgorithm::Average;
2336*89c4ff92SAndroid Build Coastguard Worker break;
2337*89c4ff92SAndroid Build Coastguard Worker }
2338*89c4ff92SAndroid Build Coastguard Worker case PoolingAlgorithm_Max:
2339*89c4ff92SAndroid Build Coastguard Worker {
2340*89c4ff92SAndroid Build Coastguard Worker desc.m_PoolType = armnn::PoolingAlgorithm::Max;
2341*89c4ff92SAndroid Build Coastguard Worker break;
2342*89c4ff92SAndroid Build Coastguard Worker }
2343*89c4ff92SAndroid Build Coastguard Worker case PoolingAlgorithm_L2:
2344*89c4ff92SAndroid Build Coastguard Worker {
2345*89c4ff92SAndroid Build Coastguard Worker desc.m_PoolType = armnn::PoolingAlgorithm::L2;
2346*89c4ff92SAndroid Build Coastguard Worker break;
2347*89c4ff92SAndroid Build Coastguard Worker }
2348*89c4ff92SAndroid Build Coastguard Worker default:
2349*89c4ff92SAndroid Build Coastguard Worker {
2350*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT_MSG(false, "Unsupported pooling algorithm");
2351*89c4ff92SAndroid Build Coastguard Worker }
2352*89c4ff92SAndroid Build Coastguard Worker }
2353*89c4ff92SAndroid Build Coastguard Worker
2354*89c4ff92SAndroid Build Coastguard Worker switch (pooling2dDesc->outputShapeRounding())
2355*89c4ff92SAndroid Build Coastguard Worker {
2356*89c4ff92SAndroid Build Coastguard Worker case OutputShapeRounding_Floor:
2357*89c4ff92SAndroid Build Coastguard Worker {
2358*89c4ff92SAndroid Build Coastguard Worker desc.m_OutputShapeRounding = armnn::OutputShapeRounding::Floor;
2359*89c4ff92SAndroid Build Coastguard Worker break;
2360*89c4ff92SAndroid Build Coastguard Worker }
2361*89c4ff92SAndroid Build Coastguard Worker case OutputShapeRounding_Ceiling:
2362*89c4ff92SAndroid Build Coastguard Worker {
2363*89c4ff92SAndroid Build Coastguard Worker desc.m_OutputShapeRounding = armnn::OutputShapeRounding::Ceiling;
2364*89c4ff92SAndroid Build Coastguard Worker break;
2365*89c4ff92SAndroid Build Coastguard Worker }
2366*89c4ff92SAndroid Build Coastguard Worker default:
2367*89c4ff92SAndroid Build Coastguard Worker {
2368*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT_MSG(false, "Unsupported output shape rounding");
2369*89c4ff92SAndroid Build Coastguard Worker }
2370*89c4ff92SAndroid Build Coastguard Worker }
2371*89c4ff92SAndroid Build Coastguard Worker
2372*89c4ff92SAndroid Build Coastguard Worker switch (pooling2dDesc->paddingMethod())
2373*89c4ff92SAndroid Build Coastguard Worker {
2374*89c4ff92SAndroid Build Coastguard Worker case PaddingMethod_Exclude:
2375*89c4ff92SAndroid Build Coastguard Worker {
2376*89c4ff92SAndroid Build Coastguard Worker desc.m_PaddingMethod = armnn::PaddingMethod::Exclude;
2377*89c4ff92SAndroid Build Coastguard Worker break;
2378*89c4ff92SAndroid Build Coastguard Worker }
2379*89c4ff92SAndroid Build Coastguard Worker case PaddingMethod_IgnoreValue:
2380*89c4ff92SAndroid Build Coastguard Worker {
2381*89c4ff92SAndroid Build Coastguard Worker desc.m_PaddingMethod = armnn::PaddingMethod::IgnoreValue;
2382*89c4ff92SAndroid Build Coastguard Worker break;
2383*89c4ff92SAndroid Build Coastguard Worker }
2384*89c4ff92SAndroid Build Coastguard Worker default:
2385*89c4ff92SAndroid Build Coastguard Worker {
2386*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT_MSG(false, "Unsupported padding method");
2387*89c4ff92SAndroid Build Coastguard Worker }
2388*89c4ff92SAndroid Build Coastguard Worker }
2389*89c4ff92SAndroid Build Coastguard Worker
2390*89c4ff92SAndroid Build Coastguard Worker switch (pooling2dDesc->dataLayout())
2391*89c4ff92SAndroid Build Coastguard Worker {
2392*89c4ff92SAndroid Build Coastguard Worker case DataLayout_NCHW:
2393*89c4ff92SAndroid Build Coastguard Worker {
2394*89c4ff92SAndroid Build Coastguard Worker desc.m_DataLayout = armnn::DataLayout::NCHW;
2395*89c4ff92SAndroid Build Coastguard Worker break;
2396*89c4ff92SAndroid Build Coastguard Worker }
2397*89c4ff92SAndroid Build Coastguard Worker case DataLayout_NHWC:
2398*89c4ff92SAndroid Build Coastguard Worker {
2399*89c4ff92SAndroid Build Coastguard Worker desc.m_DataLayout = armnn::DataLayout::NHWC;
2400*89c4ff92SAndroid Build Coastguard Worker break;
2401*89c4ff92SAndroid Build Coastguard Worker }
2402*89c4ff92SAndroid Build Coastguard Worker default:
2403*89c4ff92SAndroid Build Coastguard Worker {
2404*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT_MSG(false, "Unsupported data layout");
2405*89c4ff92SAndroid Build Coastguard Worker }
2406*89c4ff92SAndroid Build Coastguard Worker }
2407*89c4ff92SAndroid Build Coastguard Worker
2408*89c4ff92SAndroid Build Coastguard Worker desc.m_PadRight = pooling2dDesc->padRight();
2409*89c4ff92SAndroid Build Coastguard Worker desc.m_PadLeft = pooling2dDesc->padLeft();
2410*89c4ff92SAndroid Build Coastguard Worker desc.m_PadBottom = pooling2dDesc->padBottom();
2411*89c4ff92SAndroid Build Coastguard Worker desc.m_PadTop = pooling2dDesc->padTop();
2412*89c4ff92SAndroid Build Coastguard Worker desc.m_StrideX = pooling2dDesc->strideX();
2413*89c4ff92SAndroid Build Coastguard Worker desc.m_StrideY = pooling2dDesc->strideY();
2414*89c4ff92SAndroid Build Coastguard Worker desc.m_PoolWidth = pooling2dDesc->poolWidth();
2415*89c4ff92SAndroid Build Coastguard Worker desc.m_PoolHeight = pooling2dDesc->poolHeight();
2416*89c4ff92SAndroid Build Coastguard Worker
2417*89c4ff92SAndroid Build Coastguard Worker return desc;
2418*89c4ff92SAndroid Build Coastguard Worker }
2419*89c4ff92SAndroid Build Coastguard Worker
GetPooling3dDescriptor(Pooling3dDescriptor pooling3dDesc,unsigned int layerIndex)2420*89c4ff92SAndroid Build Coastguard Worker armnn::Pooling3dDescriptor IDeserializer::DeserializerImpl::GetPooling3dDescriptor(Pooling3dDescriptor pooling3dDesc,
2421*89c4ff92SAndroid Build Coastguard Worker unsigned int layerIndex)
2422*89c4ff92SAndroid Build Coastguard Worker {
2423*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(layerIndex);
2424*89c4ff92SAndroid Build Coastguard Worker armnn::Pooling3dDescriptor desc;
2425*89c4ff92SAndroid Build Coastguard Worker
2426*89c4ff92SAndroid Build Coastguard Worker switch (pooling3dDesc->poolType())
2427*89c4ff92SAndroid Build Coastguard Worker {
2428*89c4ff92SAndroid Build Coastguard Worker case PoolingAlgorithm_Average:
2429*89c4ff92SAndroid Build Coastguard Worker {
2430*89c4ff92SAndroid Build Coastguard Worker desc.m_PoolType = armnn::PoolingAlgorithm::Average;
2431*89c4ff92SAndroid Build Coastguard Worker break;
2432*89c4ff92SAndroid Build Coastguard Worker }
2433*89c4ff92SAndroid Build Coastguard Worker case PoolingAlgorithm_Max:
2434*89c4ff92SAndroid Build Coastguard Worker {
2435*89c4ff92SAndroid Build Coastguard Worker desc.m_PoolType = armnn::PoolingAlgorithm::Max;
2436*89c4ff92SAndroid Build Coastguard Worker break;
2437*89c4ff92SAndroid Build Coastguard Worker }
2438*89c4ff92SAndroid Build Coastguard Worker case PoolingAlgorithm_L2:
2439*89c4ff92SAndroid Build Coastguard Worker {
2440*89c4ff92SAndroid Build Coastguard Worker desc.m_PoolType = armnn::PoolingAlgorithm::L2;
2441*89c4ff92SAndroid Build Coastguard Worker break;
2442*89c4ff92SAndroid Build Coastguard Worker }
2443*89c4ff92SAndroid Build Coastguard Worker default:
2444*89c4ff92SAndroid Build Coastguard Worker {
2445*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT_MSG(false, "Unsupported pooling algorithm");
2446*89c4ff92SAndroid Build Coastguard Worker }
2447*89c4ff92SAndroid Build Coastguard Worker }
2448*89c4ff92SAndroid Build Coastguard Worker
2449*89c4ff92SAndroid Build Coastguard Worker switch (pooling3dDesc->outputShapeRounding())
2450*89c4ff92SAndroid Build Coastguard Worker {
2451*89c4ff92SAndroid Build Coastguard Worker case OutputShapeRounding_Floor:
2452*89c4ff92SAndroid Build Coastguard Worker {
2453*89c4ff92SAndroid Build Coastguard Worker desc.m_OutputShapeRounding = armnn::OutputShapeRounding::Floor;
2454*89c4ff92SAndroid Build Coastguard Worker break;
2455*89c4ff92SAndroid Build Coastguard Worker }
2456*89c4ff92SAndroid Build Coastguard Worker case OutputShapeRounding_Ceiling:
2457*89c4ff92SAndroid Build Coastguard Worker {
2458*89c4ff92SAndroid Build Coastguard Worker desc.m_OutputShapeRounding = armnn::OutputShapeRounding::Ceiling;
2459*89c4ff92SAndroid Build Coastguard Worker break;
2460*89c4ff92SAndroid Build Coastguard Worker }
2461*89c4ff92SAndroid Build Coastguard Worker default:
2462*89c4ff92SAndroid Build Coastguard Worker {
2463*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT_MSG(false, "Unsupported output shape rounding");
2464*89c4ff92SAndroid Build Coastguard Worker }
2465*89c4ff92SAndroid Build Coastguard Worker }
2466*89c4ff92SAndroid Build Coastguard Worker
2467*89c4ff92SAndroid Build Coastguard Worker switch (pooling3dDesc->paddingMethod())
2468*89c4ff92SAndroid Build Coastguard Worker {
2469*89c4ff92SAndroid Build Coastguard Worker case PaddingMethod_Exclude:
2470*89c4ff92SAndroid Build Coastguard Worker {
2471*89c4ff92SAndroid Build Coastguard Worker desc.m_PaddingMethod = armnn::PaddingMethod::Exclude;
2472*89c4ff92SAndroid Build Coastguard Worker break;
2473*89c4ff92SAndroid Build Coastguard Worker }
2474*89c4ff92SAndroid Build Coastguard Worker case PaddingMethod_IgnoreValue:
2475*89c4ff92SAndroid Build Coastguard Worker {
2476*89c4ff92SAndroid Build Coastguard Worker desc.m_PaddingMethod = armnn::PaddingMethod::IgnoreValue;
2477*89c4ff92SAndroid Build Coastguard Worker break;
2478*89c4ff92SAndroid Build Coastguard Worker }
2479*89c4ff92SAndroid Build Coastguard Worker default:
2480*89c4ff92SAndroid Build Coastguard Worker {
2481*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT_MSG(false, "Unsupported padding method");
2482*89c4ff92SAndroid Build Coastguard Worker }
2483*89c4ff92SAndroid Build Coastguard Worker }
2484*89c4ff92SAndroid Build Coastguard Worker
2485*89c4ff92SAndroid Build Coastguard Worker switch (pooling3dDesc->dataLayout())
2486*89c4ff92SAndroid Build Coastguard Worker {
2487*89c4ff92SAndroid Build Coastguard Worker case DataLayout_NCDHW:
2488*89c4ff92SAndroid Build Coastguard Worker {
2489*89c4ff92SAndroid Build Coastguard Worker desc.m_DataLayout = armnn::DataLayout::NCDHW;
2490*89c4ff92SAndroid Build Coastguard Worker break;
2491*89c4ff92SAndroid Build Coastguard Worker }
2492*89c4ff92SAndroid Build Coastguard Worker case DataLayout_NDHWC:
2493*89c4ff92SAndroid Build Coastguard Worker {
2494*89c4ff92SAndroid Build Coastguard Worker desc.m_DataLayout = armnn::DataLayout::NDHWC;
2495*89c4ff92SAndroid Build Coastguard Worker break;
2496*89c4ff92SAndroid Build Coastguard Worker }
2497*89c4ff92SAndroid Build Coastguard Worker default:
2498*89c4ff92SAndroid Build Coastguard Worker {
2499*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT_MSG(false, "Unsupported data layout");
2500*89c4ff92SAndroid Build Coastguard Worker }
2501*89c4ff92SAndroid Build Coastguard Worker }
2502*89c4ff92SAndroid Build Coastguard Worker
2503*89c4ff92SAndroid Build Coastguard Worker desc.m_PadRight = pooling3dDesc->padRight();
2504*89c4ff92SAndroid Build Coastguard Worker desc.m_PadLeft = pooling3dDesc->padLeft();
2505*89c4ff92SAndroid Build Coastguard Worker desc.m_PadBottom = pooling3dDesc->padBottom();
2506*89c4ff92SAndroid Build Coastguard Worker desc.m_PadTop = pooling3dDesc->padTop();
2507*89c4ff92SAndroid Build Coastguard Worker desc.m_PadFront = pooling3dDesc->padFront();
2508*89c4ff92SAndroid Build Coastguard Worker desc.m_PadBack = pooling3dDesc->padBack();
2509*89c4ff92SAndroid Build Coastguard Worker desc.m_StrideX = pooling3dDesc->strideX();
2510*89c4ff92SAndroid Build Coastguard Worker desc.m_StrideY = pooling3dDesc->strideY();
2511*89c4ff92SAndroid Build Coastguard Worker desc.m_StrideZ = pooling3dDesc->strideZ();
2512*89c4ff92SAndroid Build Coastguard Worker desc.m_PoolWidth = pooling3dDesc->poolWidth();
2513*89c4ff92SAndroid Build Coastguard Worker desc.m_PoolHeight = pooling3dDesc->poolHeight();
2514*89c4ff92SAndroid Build Coastguard Worker desc.m_PoolDepth = pooling3dDesc->poolDepth();
2515*89c4ff92SAndroid Build Coastguard Worker
2516*89c4ff92SAndroid Build Coastguard Worker return desc;
2517*89c4ff92SAndroid Build Coastguard Worker }
2518*89c4ff92SAndroid Build Coastguard Worker
ParsePooling2d(GraphPtr graph,unsigned int layerIndex)2519*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParsePooling2d(GraphPtr graph, unsigned int layerIndex)
2520*89c4ff92SAndroid Build Coastguard Worker {
2521*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
2522*89c4ff92SAndroid Build Coastguard Worker
2523*89c4ff92SAndroid Build Coastguard Worker auto pooling2dDes = graph->layers()->Get(layerIndex)->layer_as_Pooling2dLayer()->descriptor();
2524*89c4ff92SAndroid Build Coastguard Worker auto inputs = GetInputs(graph, layerIndex);
2525*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 1);
2526*89c4ff92SAndroid Build Coastguard Worker
2527*89c4ff92SAndroid Build Coastguard Worker auto outputs = GetOutputs(graph, layerIndex);
2528*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
2529*89c4ff92SAndroid Build Coastguard Worker auto outputInfo = ToTensorInfo(outputs[0]);
2530*89c4ff92SAndroid Build Coastguard Worker
2531*89c4ff92SAndroid Build Coastguard Worker auto pooling2dDescriptor = GetPooling2dDescriptor(pooling2dDes, layerIndex);
2532*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
2533*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddPooling2dLayer(pooling2dDescriptor, layerName.c_str());
2534*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputInfo);
2535*89c4ff92SAndroid Build Coastguard Worker
2536*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
2537*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
2538*89c4ff92SAndroid Build Coastguard Worker }
2539*89c4ff92SAndroid Build Coastguard Worker
ParsePooling3d(GraphPtr graph,unsigned int layerIndex)2540*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParsePooling3d(GraphPtr graph, unsigned int layerIndex)
2541*89c4ff92SAndroid Build Coastguard Worker {
2542*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
2543*89c4ff92SAndroid Build Coastguard Worker
2544*89c4ff92SAndroid Build Coastguard Worker auto pooling3dDes = graph->layers()->Get(layerIndex)->layer_as_Pooling3dLayer()->descriptor();
2545*89c4ff92SAndroid Build Coastguard Worker auto inputs = GetInputs(graph, layerIndex);
2546*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 1);
2547*89c4ff92SAndroid Build Coastguard Worker
2548*89c4ff92SAndroid Build Coastguard Worker auto outputs = GetOutputs(graph, layerIndex);
2549*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
2550*89c4ff92SAndroid Build Coastguard Worker auto outputInfo = ToTensorInfo(outputs[0]);
2551*89c4ff92SAndroid Build Coastguard Worker
2552*89c4ff92SAndroid Build Coastguard Worker auto pooling3dDescriptor = GetPooling3dDescriptor(pooling3dDes, layerIndex);
2553*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
2554*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddPooling3dLayer(pooling3dDescriptor, layerName.c_str());
2555*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputInfo);
2556*89c4ff92SAndroid Build Coastguard Worker
2557*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
2558*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
2559*89c4ff92SAndroid Build Coastguard Worker }
2560*89c4ff92SAndroid Build Coastguard Worker
ParseQuantize(GraphPtr graph,unsigned int layerIndex)2561*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseQuantize(GraphPtr graph, unsigned int layerIndex)
2562*89c4ff92SAndroid Build Coastguard Worker {
2563*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
2564*89c4ff92SAndroid Build Coastguard Worker
2565*89c4ff92SAndroid Build Coastguard Worker auto inputs = GetInputs(graph, layerIndex);
2566*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 1);
2567*89c4ff92SAndroid Build Coastguard Worker
2568*89c4ff92SAndroid Build Coastguard Worker auto outputs = GetOutputs(graph, layerIndex);
2569*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
2570*89c4ff92SAndroid Build Coastguard Worker auto outputInfo = ToTensorInfo(outputs[0]);
2571*89c4ff92SAndroid Build Coastguard Worker
2572*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
2573*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddQuantizeLayer(layerName.c_str());
2574*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputInfo);
2575*89c4ff92SAndroid Build Coastguard Worker
2576*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
2577*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
2578*89c4ff92SAndroid Build Coastguard Worker }
2579*89c4ff92SAndroid Build Coastguard Worker
/// Computes the TensorInfo produced by reshaping @p inputTensorInfo to @p targetDimsIn.
///
/// At most one target dimension may be -1, which appears as 0xFFFFFFFF in the unsigned vector. That dimension
/// is inferred by dividing the input element count by the product of the other target dimensions (integer
/// division).
///
/// @param inputTensorInfo Info of the tensor being reshaped. Its data type and quantization are preserved.
/// @param targetDimsIn    Requested target shape, possibly containing a single -1.
/// @return A copy of @p inputTensorInfo with its shape replaced by the resolved target shape.
/// @throws ParseException if more than one component is -1, or if the -1 component must be inferred while
///         another component is 0.
armnn::TensorInfo IDeserializer::DeserializerImpl::OutputShapeOfReshape(const armnn::TensorInfo& inputTensorInfo,
                                                                        const std::vector<uint32_t>& targetDimsIn)
{
    // -1 as stored in a uint32_t vector. Spelled explicitly to avoid an implicit signed/unsigned comparison.
    constexpr uint32_t stretchMarker = std::numeric_limits<uint32_t>::max();

    std::vector<unsigned int> outputDims(targetDimsIn.begin(), targetDimsIn.end());
    const auto stretchDim = std::find(targetDimsIn.begin(), targetDimsIn.end(), stretchMarker);

    if (stretchDim != targetDimsIn.end())
    {
        if (std::find(std::next(stretchDim), targetDimsIn.end(), stretchMarker) != targetDimsIn.end())
        {
            throw ParseException(fmt::format("At most one component of shape can be -1 {}",
                                             CHECK_LOCATION().AsString()));
        }

        // Seeding the product with -1 cancels the single -1 entry (reinterpreted as int32_t), leaving the
        // product of the explicitly specified dimensions.
        auto targetNumElements =
            armnn::numeric_cast<unsigned int>(
                std::accumulate(targetDimsIn.begin(), targetDimsIn.end(), -1, std::multiplies<int32_t>()));

        // A zero-sized explicit dimension would make the division below a divide-by-zero.
        if (targetNumElements == 0)
        {
            throw ParseException(fmt::format("Cannot infer the -1 component of a shape that contains a zero "
                                             "component {}",
                                             CHECK_LOCATION().AsString()));
        }

        auto stretchIndex = static_cast<size_t>(std::distance(targetDimsIn.begin(), stretchDim));
        outputDims[stretchIndex] = inputTensorInfo.GetNumElements() / targetNumElements;
    }

    TensorShape outputShape = TensorShape(static_cast<unsigned int>(outputDims.size()), outputDims.data());

    armnn::TensorInfo reshapeInfo = inputTensorInfo;
    reshapeInfo.SetShape(outputShape);

    return reshapeInfo;
}
2609*89c4ff92SAndroid Build Coastguard Worker
ParseRank(GraphPtr graph,unsigned int layerIndex)2610*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseRank(GraphPtr graph, unsigned int layerIndex)
2611*89c4ff92SAndroid Build Coastguard Worker {
2612*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
2613*89c4ff92SAndroid Build Coastguard Worker
2614*89c4ff92SAndroid Build Coastguard Worker TensorRawPtrVector inputs = GetInputs(graph, layerIndex);
2615*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 1);
2616*89c4ff92SAndroid Build Coastguard Worker
2617*89c4ff92SAndroid Build Coastguard Worker TensorRawPtrVector outputs = GetOutputs(graph, layerIndex);
2618*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
2619*89c4ff92SAndroid Build Coastguard Worker
2620*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
2621*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddRankLayer( layerName.c_str());
2622*89c4ff92SAndroid Build Coastguard Worker
2623*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
2624*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
2625*89c4ff92SAndroid Build Coastguard Worker
2626*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
2627*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
2628*89c4ff92SAndroid Build Coastguard Worker }
2629*89c4ff92SAndroid Build Coastguard Worker
ParseReduce(GraphPtr graph,unsigned int layerIndex)2630*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseReduce(GraphPtr graph, unsigned int layerIndex)
2631*89c4ff92SAndroid Build Coastguard Worker {
2632*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
2633*89c4ff92SAndroid Build Coastguard Worker CHECK_LOCATION();
2634*89c4ff92SAndroid Build Coastguard Worker
2635*89c4ff92SAndroid Build Coastguard Worker auto inputs = GetInputs(graph, layerIndex);
2636*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 1);
2637*89c4ff92SAndroid Build Coastguard Worker
2638*89c4ff92SAndroid Build Coastguard Worker auto outputs = GetOutputs(graph, layerIndex);
2639*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
2640*89c4ff92SAndroid Build Coastguard Worker
2641*89c4ff92SAndroid Build Coastguard Worker auto fbLayer = graph->layers()->Get(layerIndex)->layer_as_ReduceLayer();
2642*89c4ff92SAndroid Build Coastguard Worker auto fbDescriptor = fbLayer->descriptor();
2643*89c4ff92SAndroid Build Coastguard Worker auto flatBufferAxis = fbDescriptor->axis();
2644*89c4ff92SAndroid Build Coastguard Worker
2645*89c4ff92SAndroid Build Coastguard Worker armnn::ReduceDescriptor descriptor;
2646*89c4ff92SAndroid Build Coastguard Worker descriptor.m_KeepDims = fbDescriptor->keepDims();
2647*89c4ff92SAndroid Build Coastguard Worker descriptor.m_vAxis = std::vector<unsigned int>(flatBufferAxis->begin(), flatBufferAxis->end());
2648*89c4ff92SAndroid Build Coastguard Worker descriptor.m_ReduceOperation = ToReduceOperation(fbDescriptor->reduceOperation());
2649*89c4ff92SAndroid Build Coastguard Worker
2650*89c4ff92SAndroid Build Coastguard Worker const std::string& layerName = GetLayerName(graph, layerIndex);
2651*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddReduceLayer(descriptor, layerName.c_str());
2652*89c4ff92SAndroid Build Coastguard Worker
2653*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
2654*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
2655*89c4ff92SAndroid Build Coastguard Worker
2656*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
2657*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
2658*89c4ff92SAndroid Build Coastguard Worker }
2659*89c4ff92SAndroid Build Coastguard Worker
ParseReshape(GraphPtr graph,unsigned int layerIndex)2660*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseReshape(GraphPtr graph, unsigned int layerIndex)
2661*89c4ff92SAndroid Build Coastguard Worker {
2662*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
2663*89c4ff92SAndroid Build Coastguard Worker auto inputs = GetInputs(graph, layerIndex);
2664*89c4ff92SAndroid Build Coastguard Worker
2665*89c4ff92SAndroid Build Coastguard Worker auto outputs = GetOutputs(graph, layerIndex);
2666*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
2667*89c4ff92SAndroid Build Coastguard Worker
2668*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]);
2669*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo actualOutputTensorInfo = ToTensorInfo(outputs[0]);
2670*89c4ff92SAndroid Build Coastguard Worker
2671*89c4ff92SAndroid Build Coastguard Worker const auto targetDims = graph->layers()->Get(layerIndex)->layer_as_ReshapeLayer()->descriptor()->targetShape();
2672*89c4ff92SAndroid Build Coastguard Worker std::vector<uint32_t> outputDims(targetDims->begin(), targetDims->begin() + targetDims->size());
2673*89c4ff92SAndroid Build Coastguard Worker
2674*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo reshapeOutputTensorInfo = DeserializerImpl::OutputShapeOfReshape(inputTensorInfo, outputDims);
2675*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorShape& reshapeOutputTensorShape = reshapeOutputTensorInfo.GetShape();
2676*89c4ff92SAndroid Build Coastguard Worker
2677*89c4ff92SAndroid Build Coastguard Worker const std::vector<uint32_t> expectedDims(outputs[0]->dimensions()->begin(),
2678*89c4ff92SAndroid Build Coastguard Worker outputs[0]->dimensions()->begin() + outputs[0]->dimensions()->size());
2679*89c4ff92SAndroid Build Coastguard Worker
2680*89c4ff92SAndroid Build Coastguard Worker if (inputs.size() > 1 && !CheckShape(reshapeOutputTensorShape, expectedDims))
2681*89c4ff92SAndroid Build Coastguard Worker {
2682*89c4ff92SAndroid Build Coastguard Worker std::stringstream ss;
2683*89c4ff92SAndroid Build Coastguard Worker ss << "New shape defined in reshape parameters "
2684*89c4ff92SAndroid Build Coastguard Worker << reshapeOutputTensorShape
2685*89c4ff92SAndroid Build Coastguard Worker << " does not equal output shape "
2686*89c4ff92SAndroid Build Coastguard Worker << actualOutputTensorInfo.GetShape()
2687*89c4ff92SAndroid Build Coastguard Worker << ": "
2688*89c4ff92SAndroid Build Coastguard Worker << CHECK_LOCATION().AsString();
2689*89c4ff92SAndroid Build Coastguard Worker throw ParseException(ss.str());
2690*89c4ff92SAndroid Build Coastguard Worker }
2691*89c4ff92SAndroid Build Coastguard Worker
2692*89c4ff92SAndroid Build Coastguard Worker armnn::ReshapeDescriptor reshapeDesc;
2693*89c4ff92SAndroid Build Coastguard Worker reshapeDesc.m_TargetShape = reshapeOutputTensorShape;
2694*89c4ff92SAndroid Build Coastguard Worker
2695*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
2696*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddReshapeLayer(reshapeDesc, layerName.c_str());
2697*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(reshapeOutputTensorInfo);
2698*89c4ff92SAndroid Build Coastguard Worker
2699*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
2700*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
2701*89c4ff92SAndroid Build Coastguard Worker }
2702*89c4ff92SAndroid Build Coastguard Worker
ParseResize(GraphPtr graph,unsigned int layerIndex)2703*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseResize(GraphPtr graph, unsigned int layerIndex)
2704*89c4ff92SAndroid Build Coastguard Worker {
2705*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
2706*89c4ff92SAndroid Build Coastguard Worker
2707*89c4ff92SAndroid Build Coastguard Worker TensorRawPtrVector inputs = GetInputs(graph, layerIndex);
2708*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 1);
2709*89c4ff92SAndroid Build Coastguard Worker
2710*89c4ff92SAndroid Build Coastguard Worker TensorRawPtrVector outputs = GetOutputs(graph, layerIndex);
2711*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
2712*89c4ff92SAndroid Build Coastguard Worker
2713*89c4ff92SAndroid Build Coastguard Worker auto flatBufferDescriptor = graph->layers()->Get(layerIndex)->layer_as_ResizeLayer()->descriptor();
2714*89c4ff92SAndroid Build Coastguard Worker
2715*89c4ff92SAndroid Build Coastguard Worker armnn::ResizeDescriptor descriptor;
2716*89c4ff92SAndroid Build Coastguard Worker descriptor.m_TargetWidth = flatBufferDescriptor->targetWidth();
2717*89c4ff92SAndroid Build Coastguard Worker descriptor.m_TargetHeight = flatBufferDescriptor->targetHeight();
2718*89c4ff92SAndroid Build Coastguard Worker descriptor.m_Method = ToResizeMethod(flatBufferDescriptor->method());
2719*89c4ff92SAndroid Build Coastguard Worker descriptor.m_DataLayout = ToDataLayout(flatBufferDescriptor->dataLayout());
2720*89c4ff92SAndroid Build Coastguard Worker descriptor.m_AlignCorners = flatBufferDescriptor->alignCorners();
2721*89c4ff92SAndroid Build Coastguard Worker descriptor.m_HalfPixelCenters = flatBufferDescriptor->halfPixelCenters();
2722*89c4ff92SAndroid Build Coastguard Worker
2723*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
2724*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddResizeLayer(descriptor, layerName.c_str());
2725*89c4ff92SAndroid Build Coastguard Worker
2726*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
2727*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
2728*89c4ff92SAndroid Build Coastguard Worker
2729*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
2730*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
2731*89c4ff92SAndroid Build Coastguard Worker }
2732*89c4ff92SAndroid Build Coastguard Worker
2733*89c4ff92SAndroid Build Coastguard Worker
2734*89c4ff92SAndroid Build Coastguard Worker /// @Note The ResizeBiliniar operation was deprecated and removed in favor of the Resize operation.
2735*89c4ff92SAndroid Build Coastguard Worker /// This function is kept for backwards compatibility.
ParseResizeBilinear(GraphPtr graph,unsigned int layerIndex)2736*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseResizeBilinear(GraphPtr graph, unsigned int layerIndex)
2737*89c4ff92SAndroid Build Coastguard Worker {
2738*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
2739*89c4ff92SAndroid Build Coastguard Worker
2740*89c4ff92SAndroid Build Coastguard Worker TensorRawPtrVector inputs = GetInputs(graph, layerIndex);
2741*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 1);
2742*89c4ff92SAndroid Build Coastguard Worker
2743*89c4ff92SAndroid Build Coastguard Worker TensorRawPtrVector outputs = GetOutputs(graph, layerIndex);
2744*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
2745*89c4ff92SAndroid Build Coastguard Worker
2746*89c4ff92SAndroid Build Coastguard Worker auto flatBufferDescriptor = graph->layers()->Get(layerIndex)->layer_as_ResizeBilinearLayer()->descriptor();
2747*89c4ff92SAndroid Build Coastguard Worker
2748*89c4ff92SAndroid Build Coastguard Worker armnn::ResizeDescriptor descriptor;
2749*89c4ff92SAndroid Build Coastguard Worker descriptor.m_TargetWidth = flatBufferDescriptor->targetWidth();
2750*89c4ff92SAndroid Build Coastguard Worker descriptor.m_TargetHeight = flatBufferDescriptor->targetHeight();
2751*89c4ff92SAndroid Build Coastguard Worker descriptor.m_Method = armnn::ResizeMethod::Bilinear;
2752*89c4ff92SAndroid Build Coastguard Worker descriptor.m_DataLayout = ToDataLayout(flatBufferDescriptor->dataLayout());
2753*89c4ff92SAndroid Build Coastguard Worker descriptor.m_AlignCorners = flatBufferDescriptor->alignCorners();
2754*89c4ff92SAndroid Build Coastguard Worker descriptor.m_HalfPixelCenters = flatBufferDescriptor->halfPixelCenters();
2755*89c4ff92SAndroid Build Coastguard Worker
2756*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
2757*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddResizeLayer(descriptor, layerName.c_str());
2758*89c4ff92SAndroid Build Coastguard Worker
2759*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
2760*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
2761*89c4ff92SAndroid Build Coastguard Worker
2762*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
2763*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
2764*89c4ff92SAndroid Build Coastguard Worker }
2765*89c4ff92SAndroid Build Coastguard Worker
ParseShape(GraphPtr graph,unsigned int layerIndex)2766*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseShape(GraphPtr graph, unsigned int layerIndex)
2767*89c4ff92SAndroid Build Coastguard Worker {
2768*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
2769*89c4ff92SAndroid Build Coastguard Worker
2770*89c4ff92SAndroid Build Coastguard Worker TensorRawPtrVector inputs = GetInputs(graph, layerIndex);
2771*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 1);
2772*89c4ff92SAndroid Build Coastguard Worker
2773*89c4ff92SAndroid Build Coastguard Worker TensorRawPtrVector outputs = GetOutputs(graph, layerIndex);
2774*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
2775*89c4ff92SAndroid Build Coastguard Worker
2776*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
2777*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddShapeLayer( layerName.c_str());
2778*89c4ff92SAndroid Build Coastguard Worker
2779*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
2780*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
2781*89c4ff92SAndroid Build Coastguard Worker
2782*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
2783*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
2784*89c4ff92SAndroid Build Coastguard Worker }
2785*89c4ff92SAndroid Build Coastguard Worker
ParseSoftmax(GraphPtr graph,unsigned int layerIndex)2786*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseSoftmax(GraphPtr graph, unsigned int layerIndex)
2787*89c4ff92SAndroid Build Coastguard Worker {
2788*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
2789*89c4ff92SAndroid Build Coastguard Worker
2790*89c4ff92SAndroid Build Coastguard Worker TensorRawPtrVector inputs = GetInputs(graph, layerIndex);
2791*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 1);
2792*89c4ff92SAndroid Build Coastguard Worker
2793*89c4ff92SAndroid Build Coastguard Worker TensorRawPtrVector outputs = GetOutputs(graph, layerIndex);
2794*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
2795*89c4ff92SAndroid Build Coastguard Worker
2796*89c4ff92SAndroid Build Coastguard Worker armnn::SoftmaxDescriptor descriptor;
2797*89c4ff92SAndroid Build Coastguard Worker descriptor.m_Beta = graph->layers()->Get(layerIndex)->layer_as_SoftmaxLayer()->descriptor()->beta();
2798*89c4ff92SAndroid Build Coastguard Worker descriptor.m_Axis = graph->layers()->Get(layerIndex)->layer_as_SoftmaxLayer()->descriptor()->axis();
2799*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
2800*89c4ff92SAndroid Build Coastguard Worker
2801*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddSoftmaxLayer(descriptor, layerName.c_str());
2802*89c4ff92SAndroid Build Coastguard Worker
2803*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
2804*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
2805*89c4ff92SAndroid Build Coastguard Worker
2806*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
2807*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
2808*89c4ff92SAndroid Build Coastguard Worker }
2809*89c4ff92SAndroid Build Coastguard Worker
ParseSpaceToBatchNd(GraphPtr graph,unsigned int layerIndex)2810*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseSpaceToBatchNd(GraphPtr graph, unsigned int layerIndex)
2811*89c4ff92SAndroid Build Coastguard Worker {
2812*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
2813*89c4ff92SAndroid Build Coastguard Worker
2814*89c4ff92SAndroid Build Coastguard Worker TensorRawPtrVector inputs = GetInputs(graph, layerIndex);
2815*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 1);
2816*89c4ff92SAndroid Build Coastguard Worker
2817*89c4ff92SAndroid Build Coastguard Worker TensorRawPtrVector outputs = GetOutputs(graph, layerIndex);
2818*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
2819*89c4ff92SAndroid Build Coastguard Worker
2820*89c4ff92SAndroid Build Coastguard Worker auto flatBufferDescriptor = graph->layers()->Get(layerIndex)->layer_as_SpaceToBatchNdLayer()->descriptor();
2821*89c4ff92SAndroid Build Coastguard Worker auto flatBufferPadList = flatBufferDescriptor->padList();
2822*89c4ff92SAndroid Build Coastguard Worker auto flatBufferBlockShape = flatBufferDescriptor->blockShape();
2823*89c4ff92SAndroid Build Coastguard Worker
2824*89c4ff92SAndroid Build Coastguard Worker if (flatBufferPadList->size() % 2 != 0)
2825*89c4ff92SAndroid Build Coastguard Worker {
2826*89c4ff92SAndroid Build Coastguard Worker throw ParseException(fmt::format("The size of the pad list must be divisible by 2 {}",
2827*89c4ff92SAndroid Build Coastguard Worker CHECK_LOCATION().AsString()));
2828*89c4ff92SAndroid Build Coastguard Worker }
2829*89c4ff92SAndroid Build Coastguard Worker
2830*89c4ff92SAndroid Build Coastguard Worker std::vector<std::pair<unsigned int, unsigned int>> padList;
2831*89c4ff92SAndroid Build Coastguard Worker padList.reserve(flatBufferPadList->size() / 2);
2832*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < flatBufferPadList->size() - 1; i += 2)
2833*89c4ff92SAndroid Build Coastguard Worker {
2834*89c4ff92SAndroid Build Coastguard Worker padList.emplace_back(flatBufferPadList->Get(i), flatBufferPadList->Get(i+1));
2835*89c4ff92SAndroid Build Coastguard Worker }
2836*89c4ff92SAndroid Build Coastguard Worker
2837*89c4ff92SAndroid Build Coastguard Worker armnn::SpaceToBatchNdDescriptor descriptor;
2838*89c4ff92SAndroid Build Coastguard Worker descriptor.m_DataLayout = ToDataLayout(flatBufferDescriptor->dataLayout());
2839*89c4ff92SAndroid Build Coastguard Worker descriptor.m_BlockShape =
2840*89c4ff92SAndroid Build Coastguard Worker std::vector<unsigned int>(flatBufferBlockShape->begin(), flatBufferBlockShape->end());
2841*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadList = padList;
2842*89c4ff92SAndroid Build Coastguard Worker
2843*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
2844*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddSpaceToBatchNdLayer(descriptor, layerName.c_str());
2845*89c4ff92SAndroid Build Coastguard Worker
2846*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
2847*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
2848*89c4ff92SAndroid Build Coastguard Worker
2849*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
2850*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
2851*89c4ff92SAndroid Build Coastguard Worker }
2852*89c4ff92SAndroid Build Coastguard Worker
ParseSpaceToDepth(GraphPtr graph,unsigned int layerIndex)2853*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseSpaceToDepth(GraphPtr graph, unsigned int layerIndex)
2854*89c4ff92SAndroid Build Coastguard Worker {
2855*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
2856*89c4ff92SAndroid Build Coastguard Worker
2857*89c4ff92SAndroid Build Coastguard Worker TensorRawPtrVector inputs = GetInputs(graph, layerIndex);
2858*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 1);
2859*89c4ff92SAndroid Build Coastguard Worker
2860*89c4ff92SAndroid Build Coastguard Worker TensorRawPtrVector outputs = GetOutputs(graph, layerIndex);
2861*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
2862*89c4ff92SAndroid Build Coastguard Worker
2863*89c4ff92SAndroid Build Coastguard Worker auto flatBufferDescriptor = graph->layers()->Get(layerIndex)->layer_as_SpaceToDepthLayer()->descriptor();
2864*89c4ff92SAndroid Build Coastguard Worker
2865*89c4ff92SAndroid Build Coastguard Worker armnn::SpaceToDepthDescriptor descriptor;
2866*89c4ff92SAndroid Build Coastguard Worker descriptor.m_BlockSize = flatBufferDescriptor->blockSize();
2867*89c4ff92SAndroid Build Coastguard Worker descriptor.m_DataLayout = ToDataLayout(flatBufferDescriptor->dataLayout());
2868*89c4ff92SAndroid Build Coastguard Worker
2869*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
2870*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddSpaceToDepthLayer(descriptor, layerName.c_str());
2871*89c4ff92SAndroid Build Coastguard Worker
2872*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
2873*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
2874*89c4ff92SAndroid Build Coastguard Worker
2875*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
2876*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
2877*89c4ff92SAndroid Build Coastguard Worker }
2878*89c4ff92SAndroid Build Coastguard Worker
GetNormalizationDescriptor(NormalizationDescriptorPtr normalizationDescriptor,unsigned int layerIndex)2879*89c4ff92SAndroid Build Coastguard Worker armnn::NormalizationDescriptor IDeserializer::DeserializerImpl::GetNormalizationDescriptor(
2880*89c4ff92SAndroid Build Coastguard Worker NormalizationDescriptorPtr normalizationDescriptor,
2881*89c4ff92SAndroid Build Coastguard Worker unsigned int layerIndex)
2882*89c4ff92SAndroid Build Coastguard Worker {
2883*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(layerIndex);
2884*89c4ff92SAndroid Build Coastguard Worker armnn::NormalizationDescriptor desc;
2885*89c4ff92SAndroid Build Coastguard Worker
2886*89c4ff92SAndroid Build Coastguard Worker switch (normalizationDescriptor->normChannelType())
2887*89c4ff92SAndroid Build Coastguard Worker {
2888*89c4ff92SAndroid Build Coastguard Worker case NormalizationAlgorithmChannel_Across:
2889*89c4ff92SAndroid Build Coastguard Worker {
2890*89c4ff92SAndroid Build Coastguard Worker desc.m_NormChannelType = armnn::NormalizationAlgorithmChannel::Across;
2891*89c4ff92SAndroid Build Coastguard Worker break;
2892*89c4ff92SAndroid Build Coastguard Worker }
2893*89c4ff92SAndroid Build Coastguard Worker case NormalizationAlgorithmChannel_Within:
2894*89c4ff92SAndroid Build Coastguard Worker {
2895*89c4ff92SAndroid Build Coastguard Worker desc.m_NormChannelType = armnn::NormalizationAlgorithmChannel::Within;
2896*89c4ff92SAndroid Build Coastguard Worker break;
2897*89c4ff92SAndroid Build Coastguard Worker }
2898*89c4ff92SAndroid Build Coastguard Worker default:
2899*89c4ff92SAndroid Build Coastguard Worker {
2900*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT_MSG(false, "Unsupported normalization channel type");
2901*89c4ff92SAndroid Build Coastguard Worker }
2902*89c4ff92SAndroid Build Coastguard Worker }
2903*89c4ff92SAndroid Build Coastguard Worker
2904*89c4ff92SAndroid Build Coastguard Worker switch (normalizationDescriptor->normMethodType())
2905*89c4ff92SAndroid Build Coastguard Worker {
2906*89c4ff92SAndroid Build Coastguard Worker case NormalizationAlgorithmMethod_LocalBrightness:
2907*89c4ff92SAndroid Build Coastguard Worker {
2908*89c4ff92SAndroid Build Coastguard Worker desc.m_NormMethodType = armnn::NormalizationAlgorithmMethod::LocalBrightness;
2909*89c4ff92SAndroid Build Coastguard Worker break;
2910*89c4ff92SAndroid Build Coastguard Worker }
2911*89c4ff92SAndroid Build Coastguard Worker case NormalizationAlgorithmMethod_LocalContrast:
2912*89c4ff92SAndroid Build Coastguard Worker {
2913*89c4ff92SAndroid Build Coastguard Worker desc.m_NormMethodType = armnn::NormalizationAlgorithmMethod::LocalContrast;
2914*89c4ff92SAndroid Build Coastguard Worker break;
2915*89c4ff92SAndroid Build Coastguard Worker }
2916*89c4ff92SAndroid Build Coastguard Worker default:
2917*89c4ff92SAndroid Build Coastguard Worker {
2918*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT_MSG(false, "Unsupported normalization method type");
2919*89c4ff92SAndroid Build Coastguard Worker }
2920*89c4ff92SAndroid Build Coastguard Worker }
2921*89c4ff92SAndroid Build Coastguard Worker
2922*89c4ff92SAndroid Build Coastguard Worker switch (normalizationDescriptor->dataLayout())
2923*89c4ff92SAndroid Build Coastguard Worker {
2924*89c4ff92SAndroid Build Coastguard Worker case DataLayout_NCHW:
2925*89c4ff92SAndroid Build Coastguard Worker {
2926*89c4ff92SAndroid Build Coastguard Worker desc.m_DataLayout = armnn::DataLayout::NCHW;
2927*89c4ff92SAndroid Build Coastguard Worker break;
2928*89c4ff92SAndroid Build Coastguard Worker }
2929*89c4ff92SAndroid Build Coastguard Worker case DataLayout_NHWC:
2930*89c4ff92SAndroid Build Coastguard Worker {
2931*89c4ff92SAndroid Build Coastguard Worker desc.m_DataLayout = armnn::DataLayout::NHWC;
2932*89c4ff92SAndroid Build Coastguard Worker break;
2933*89c4ff92SAndroid Build Coastguard Worker }
2934*89c4ff92SAndroid Build Coastguard Worker default:
2935*89c4ff92SAndroid Build Coastguard Worker {
2936*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT_MSG(false, "Unsupported data layout");
2937*89c4ff92SAndroid Build Coastguard Worker }
2938*89c4ff92SAndroid Build Coastguard Worker }
2939*89c4ff92SAndroid Build Coastguard Worker
2940*89c4ff92SAndroid Build Coastguard Worker desc.m_Alpha = normalizationDescriptor->alpha();
2941*89c4ff92SAndroid Build Coastguard Worker desc.m_Beta = normalizationDescriptor->beta();
2942*89c4ff92SAndroid Build Coastguard Worker desc.m_K = normalizationDescriptor->k();
2943*89c4ff92SAndroid Build Coastguard Worker desc.m_NormSize = normalizationDescriptor->normSize();
2944*89c4ff92SAndroid Build Coastguard Worker
2945*89c4ff92SAndroid Build Coastguard Worker return desc;
2946*89c4ff92SAndroid Build Coastguard Worker }
2947*89c4ff92SAndroid Build Coastguard Worker
ParseNormalization(GraphPtr graph,unsigned int layerIndex)2948*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseNormalization(GraphPtr graph, unsigned int layerIndex)
2949*89c4ff92SAndroid Build Coastguard Worker {
2950*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
2951*89c4ff92SAndroid Build Coastguard Worker
2952*89c4ff92SAndroid Build Coastguard Worker auto normalizationDes = graph->layers()->Get(layerIndex)->layer_as_NormalizationLayer()->descriptor();
2953*89c4ff92SAndroid Build Coastguard Worker
2954*89c4ff92SAndroid Build Coastguard Worker TensorRawPtrVector inputs = GetInputs(graph, layerIndex);
2955*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 1);
2956*89c4ff92SAndroid Build Coastguard Worker
2957*89c4ff92SAndroid Build Coastguard Worker TensorRawPtrVector outputs = GetOutputs(graph, layerIndex);
2958*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
2959*89c4ff92SAndroid Build Coastguard Worker
2960*89c4ff92SAndroid Build Coastguard Worker auto outputInfo = ToTensorInfo(outputs[0]);
2961*89c4ff92SAndroid Build Coastguard Worker
2962*89c4ff92SAndroid Build Coastguard Worker auto normalizationDescriptor = GetNormalizationDescriptor(normalizationDes, layerIndex);
2963*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
2964*89c4ff92SAndroid Build Coastguard Worker
2965*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddNormalizationLayer(normalizationDescriptor, layerName.c_str());
2966*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputInfo);
2967*89c4ff92SAndroid Build Coastguard Worker
2968*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
2969*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
2970*89c4ff92SAndroid Build Coastguard Worker }
2971*89c4ff92SAndroid Build Coastguard Worker
ParseRsqrt(GraphPtr graph,unsigned int layerIndex)2972*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseRsqrt(GraphPtr graph, unsigned int layerIndex)
2973*89c4ff92SAndroid Build Coastguard Worker {
2974*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
2975*89c4ff92SAndroid Build Coastguard Worker auto inputs = GetInputs(graph, layerIndex);
2976*89c4ff92SAndroid Build Coastguard Worker CHECK_LOCATION();
2977*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 1);
2978*89c4ff92SAndroid Build Coastguard Worker
2979*89c4ff92SAndroid Build Coastguard Worker auto outputs = GetOutputs(graph, layerIndex);
2980*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
2981*89c4ff92SAndroid Build Coastguard Worker
2982*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
2983*89c4ff92SAndroid Build Coastguard Worker
2984*89c4ff92SAndroid Build Coastguard Worker armnn::ElementwiseUnaryDescriptor descriptor(armnn::UnaryOperation::Rsqrt);
2985*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddElementwiseUnaryLayer(descriptor, layerName.c_str());
2986*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
2987*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
2988*89c4ff92SAndroid Build Coastguard Worker
2989*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
2990*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
2991*89c4ff92SAndroid Build Coastguard Worker }
2992*89c4ff92SAndroid Build Coastguard Worker
ParseSlice(GraphPtr graph,unsigned int layerIndex)2993*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseSlice(GraphPtr graph, unsigned int layerIndex)
2994*89c4ff92SAndroid Build Coastguard Worker {
2995*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
2996*89c4ff92SAndroid Build Coastguard Worker
2997*89c4ff92SAndroid Build Coastguard Worker auto inputs = GetInputs(graph, layerIndex);
2998*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 1);
2999*89c4ff92SAndroid Build Coastguard Worker
3000*89c4ff92SAndroid Build Coastguard Worker auto outputs = GetOutputs(graph, layerIndex);
3001*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
3002*89c4ff92SAndroid Build Coastguard Worker
3003*89c4ff92SAndroid Build Coastguard Worker auto fbDescriptor = graph->layers()->Get(layerIndex)->layer_as_SliceLayer()->descriptor();
3004*89c4ff92SAndroid Build Coastguard Worker
3005*89c4ff92SAndroid Build Coastguard Worker auto fbBegin = fbDescriptor->begin();
3006*89c4ff92SAndroid Build Coastguard Worker auto fbSize = fbDescriptor->size();
3007*89c4ff92SAndroid Build Coastguard Worker
3008*89c4ff92SAndroid Build Coastguard Worker if (fbBegin->size() != fbSize->size())
3009*89c4ff92SAndroid Build Coastguard Worker {
3010*89c4ff92SAndroid Build Coastguard Worker throw ParseException(fmt::format("Begin and size descriptors must have the same length {}",
3011*89c4ff92SAndroid Build Coastguard Worker CHECK_LOCATION().AsString()));
3012*89c4ff92SAndroid Build Coastguard Worker }
3013*89c4ff92SAndroid Build Coastguard Worker
3014*89c4ff92SAndroid Build Coastguard Worker armnn::SliceDescriptor descriptor;
3015*89c4ff92SAndroid Build Coastguard Worker descriptor.m_Begin.insert(descriptor.m_Begin.end(), fbBegin->begin(), fbBegin->end());
3016*89c4ff92SAndroid Build Coastguard Worker descriptor.m_Size.insert(descriptor.m_Size.end(), fbSize->begin(), fbSize->end());
3017*89c4ff92SAndroid Build Coastguard Worker
3018*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
3019*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddSliceLayer(descriptor, layerName.c_str());
3020*89c4ff92SAndroid Build Coastguard Worker
3021*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
3022*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
3023*89c4ff92SAndroid Build Coastguard Worker
3024*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
3025*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
3026*89c4ff92SAndroid Build Coastguard Worker }
3027*89c4ff92SAndroid Build Coastguard Worker
ParseStridedSlice(GraphPtr graph,unsigned int layerIndex)3028*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseStridedSlice(GraphPtr graph, unsigned int layerIndex)
3029*89c4ff92SAndroid Build Coastguard Worker {
3030*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
3031*89c4ff92SAndroid Build Coastguard Worker
3032*89c4ff92SAndroid Build Coastguard Worker TensorRawPtrVector inputs = GetInputs(graph, layerIndex);
3033*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 1);
3034*89c4ff92SAndroid Build Coastguard Worker
3035*89c4ff92SAndroid Build Coastguard Worker TensorRawPtrVector outputs = GetOutputs(graph, layerIndex);
3036*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
3037*89c4ff92SAndroid Build Coastguard Worker
3038*89c4ff92SAndroid Build Coastguard Worker auto flatBufferDescriptor = graph->layers()->Get(layerIndex)->layer_as_StridedSliceLayer()->descriptor();
3039*89c4ff92SAndroid Build Coastguard Worker
3040*89c4ff92SAndroid Build Coastguard Worker auto flatBufferBegin = flatBufferDescriptor->begin();
3041*89c4ff92SAndroid Build Coastguard Worker auto flatBufferEnd = flatBufferDescriptor->end();
3042*89c4ff92SAndroid Build Coastguard Worker auto flatBufferStride = flatBufferDescriptor->stride();
3043*89c4ff92SAndroid Build Coastguard Worker
3044*89c4ff92SAndroid Build Coastguard Worker if (!(flatBufferBegin->size() == flatBufferEnd->size() &&
3045*89c4ff92SAndroid Build Coastguard Worker flatBufferBegin->size() == flatBufferStride->size()))
3046*89c4ff92SAndroid Build Coastguard Worker {
3047*89c4ff92SAndroid Build Coastguard Worker throw ParseException(fmt::format("The size of the begin, end, and stride must be equal {}",
3048*89c4ff92SAndroid Build Coastguard Worker CHECK_LOCATION().AsString()));
3049*89c4ff92SAndroid Build Coastguard Worker }
3050*89c4ff92SAndroid Build Coastguard Worker
3051*89c4ff92SAndroid Build Coastguard Worker std::vector<int> begin(flatBufferBegin->begin(), flatBufferBegin->end());
3052*89c4ff92SAndroid Build Coastguard Worker std::vector<int> end(flatBufferEnd->begin(), flatBufferEnd->end());
3053*89c4ff92SAndroid Build Coastguard Worker std::vector<int> stride(flatBufferStride->begin(), flatBufferStride->end());
3054*89c4ff92SAndroid Build Coastguard Worker
3055*89c4ff92SAndroid Build Coastguard Worker armnn::StridedSliceDescriptor descriptor(begin, end, stride);
3056*89c4ff92SAndroid Build Coastguard Worker descriptor.m_BeginMask = flatBufferDescriptor->beginMask();
3057*89c4ff92SAndroid Build Coastguard Worker descriptor.m_EndMask = flatBufferDescriptor->endMask();
3058*89c4ff92SAndroid Build Coastguard Worker descriptor.m_ShrinkAxisMask = flatBufferDescriptor->shrinkAxisMask();
3059*89c4ff92SAndroid Build Coastguard Worker descriptor.m_EllipsisMask = flatBufferDescriptor->ellipsisMask();
3060*89c4ff92SAndroid Build Coastguard Worker descriptor.m_NewAxisMask = flatBufferDescriptor->newAxisMask();
3061*89c4ff92SAndroid Build Coastguard Worker descriptor.m_DataLayout = ToDataLayout(flatBufferDescriptor->dataLayout());
3062*89c4ff92SAndroid Build Coastguard Worker
3063*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
3064*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddStridedSliceLayer(descriptor, layerName.c_str());
3065*89c4ff92SAndroid Build Coastguard Worker
3066*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
3067*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
3068*89c4ff92SAndroid Build Coastguard Worker
3069*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
3070*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
3071*89c4ff92SAndroid Build Coastguard Worker }
3072*89c4ff92SAndroid Build Coastguard Worker
ParseSubtraction(GraphPtr graph,unsigned int layerIndex)3073*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseSubtraction(GraphPtr graph, unsigned int layerIndex)
3074*89c4ff92SAndroid Build Coastguard Worker {
3075*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
3076*89c4ff92SAndroid Build Coastguard Worker auto inputs = GetInputs(graph, layerIndex);
3077*89c4ff92SAndroid Build Coastguard Worker CHECK_LOCATION();
3078*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 2);
3079*89c4ff92SAndroid Build Coastguard Worker
3080*89c4ff92SAndroid Build Coastguard Worker auto outputs = GetOutputs(graph, layerIndex);
3081*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
3082*89c4ff92SAndroid Build Coastguard Worker
3083*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
3084*89c4ff92SAndroid Build Coastguard Worker armnn::ElementwiseBinaryDescriptor descriptor(armnn::BinaryOperation::Sub);
3085*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddElementwiseBinaryLayer(descriptor, layerName.c_str());
3086*89c4ff92SAndroid Build Coastguard Worker
3087*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
3088*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
3089*89c4ff92SAndroid Build Coastguard Worker
3090*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
3091*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
3092*89c4ff92SAndroid Build Coastguard Worker }
3093*89c4ff92SAndroid Build Coastguard Worker
ParseGather(GraphPtr graph,unsigned int layerIndex)3094*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseGather(GraphPtr graph, unsigned int layerIndex)
3095*89c4ff92SAndroid Build Coastguard Worker {
3096*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
3097*89c4ff92SAndroid Build Coastguard Worker
3098*89c4ff92SAndroid Build Coastguard Worker TensorRawPtrVector inputs = GetInputs(graph, layerIndex);
3099*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 2);
3100*89c4ff92SAndroid Build Coastguard Worker
3101*89c4ff92SAndroid Build Coastguard Worker TensorRawPtrVector outputs = GetOutputs(graph, layerIndex);
3102*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
3103*89c4ff92SAndroid Build Coastguard Worker
3104*89c4ff92SAndroid Build Coastguard Worker armnn::GatherDescriptor descriptor;
3105*89c4ff92SAndroid Build Coastguard Worker descriptor.m_Axis = graph->layers()->Get(layerIndex)->layer_as_GatherLayer()->descriptor()->axis();
3106*89c4ff92SAndroid Build Coastguard Worker
3107*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
3108*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddGatherLayer(descriptor, layerName.c_str());
3109*89c4ff92SAndroid Build Coastguard Worker
3110*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
3111*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
3112*89c4ff92SAndroid Build Coastguard Worker
3113*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
3114*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
3115*89c4ff92SAndroid Build Coastguard Worker }
3116*89c4ff92SAndroid Build Coastguard Worker
ParseGatherNd(GraphPtr graph,unsigned int layerIndex)3117*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseGatherNd(GraphPtr graph, unsigned int layerIndex)
3118*89c4ff92SAndroid Build Coastguard Worker {
3119*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
3120*89c4ff92SAndroid Build Coastguard Worker
3121*89c4ff92SAndroid Build Coastguard Worker TensorRawPtrVector inputs = GetInputs(graph, layerIndex);
3122*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 2);
3123*89c4ff92SAndroid Build Coastguard Worker
3124*89c4ff92SAndroid Build Coastguard Worker TensorRawPtrVector outputs = GetOutputs(graph, layerIndex);
3125*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
3126*89c4ff92SAndroid Build Coastguard Worker
3127*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
3128*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddGatherNdLayer(layerName.c_str());
3129*89c4ff92SAndroid Build Coastguard Worker
3130*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
3131*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
3132*89c4ff92SAndroid Build Coastguard Worker
3133*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
3134*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
3135*89c4ff92SAndroid Build Coastguard Worker }
3136*89c4ff92SAndroid Build Coastguard Worker
ParseMean(GraphPtr graph,unsigned int layerIndex)3137*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseMean(GraphPtr graph, unsigned int layerIndex)
3138*89c4ff92SAndroid Build Coastguard Worker {
3139*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
3140*89c4ff92SAndroid Build Coastguard Worker
3141*89c4ff92SAndroid Build Coastguard Worker TensorRawPtrVector inputs = GetInputs(graph, layerIndex);
3142*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 1);
3143*89c4ff92SAndroid Build Coastguard Worker
3144*89c4ff92SAndroid Build Coastguard Worker TensorRawPtrVector outputs = GetOutputs(graph, layerIndex);
3145*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
3146*89c4ff92SAndroid Build Coastguard Worker
3147*89c4ff92SAndroid Build Coastguard Worker auto flatBufferDescriptor = graph->layers()->Get(layerIndex)->layer_as_MeanLayer()->descriptor();
3148*89c4ff92SAndroid Build Coastguard Worker auto flatBufferAxis = flatBufferDescriptor->axis();
3149*89c4ff92SAndroid Build Coastguard Worker auto flatBufferKeepDims = flatBufferDescriptor->keepDims();
3150*89c4ff92SAndroid Build Coastguard Worker
3151*89c4ff92SAndroid Build Coastguard Worker armnn::MeanDescriptor descriptor;
3152*89c4ff92SAndroid Build Coastguard Worker descriptor.m_Axis = std::vector<unsigned int>(flatBufferAxis->begin(), flatBufferAxis->end());
3153*89c4ff92SAndroid Build Coastguard Worker descriptor.m_KeepDims = flatBufferKeepDims;
3154*89c4ff92SAndroid Build Coastguard Worker
3155*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
3156*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddMeanLayer(descriptor, layerName.c_str());
3157*89c4ff92SAndroid Build Coastguard Worker
3158*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
3159*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
3160*89c4ff92SAndroid Build Coastguard Worker
3161*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
3162*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
3163*89c4ff92SAndroid Build Coastguard Worker }
3164*89c4ff92SAndroid Build Coastguard Worker
ParseSplitter(GraphPtr graph,unsigned int layerIndex)3165*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseSplitter(GraphPtr graph, unsigned int layerIndex)
3166*89c4ff92SAndroid Build Coastguard Worker {
3167*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
3168*89c4ff92SAndroid Build Coastguard Worker
3169*89c4ff92SAndroid Build Coastguard Worker TensorRawPtrVector inputs = GetInputs(graph, layerIndex);
3170*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 1);
3171*89c4ff92SAndroid Build Coastguard Worker
3172*89c4ff92SAndroid Build Coastguard Worker TensorRawPtrVector outputs = GetOutputs(graph, layerIndex);
3173*89c4ff92SAndroid Build Coastguard Worker
3174*89c4ff92SAndroid Build Coastguard Worker auto flatBufferViewsDescriptor = graph->layers()->Get(layerIndex)->layer_as_SplitterLayer()->descriptor();
3175*89c4ff92SAndroid Build Coastguard Worker auto flatBufferViewSizes = flatBufferViewsDescriptor->viewSizes();
3176*89c4ff92SAndroid Build Coastguard Worker auto flatBufferOriginsDescriptor = flatBufferViewsDescriptor->origins();
3177*89c4ff92SAndroid Build Coastguard Worker auto flatBufferViewOrigins = flatBufferOriginsDescriptor->viewOrigins();
3178*89c4ff92SAndroid Build Coastguard Worker uint32_t numViews = flatBufferOriginsDescriptor->numViews();
3179*89c4ff92SAndroid Build Coastguard Worker uint32_t numDimensions = flatBufferOriginsDescriptor->numDimensions();
3180*89c4ff92SAndroid Build Coastguard Worker
3181*89c4ff92SAndroid Build Coastguard Worker // Check numViews and numDimensions corresponds to the ones already serialized ...
3182*89c4ff92SAndroid Build Coastguard Worker // numViews == flatBufferViewSizes.size();
3183*89c4ff92SAndroid Build Coastguard Worker // foreach: numDimensions == flatBufferViewSizes[x].size();
3184*89c4ff92SAndroid Build Coastguard Worker
3185*89c4ff92SAndroid Build Coastguard Worker armnn::ViewsDescriptor viewsDescriptor(numViews, numDimensions);
3186*89c4ff92SAndroid Build Coastguard Worker for(unsigned int vIdx = 0; vIdx < numViews; ++vIdx)
3187*89c4ff92SAndroid Build Coastguard Worker {
3188*89c4ff92SAndroid Build Coastguard Worker for (unsigned int dIdx = 0; dIdx < numDimensions; ++dIdx)
3189*89c4ff92SAndroid Build Coastguard Worker {
3190*89c4ff92SAndroid Build Coastguard Worker viewsDescriptor.SetViewSize(vIdx, dIdx, flatBufferViewSizes->Get(vIdx)->data()->Get(dIdx));
3191*89c4ff92SAndroid Build Coastguard Worker viewsDescriptor.SetViewOriginCoord(vIdx, dIdx, flatBufferViewOrigins->Get(vIdx)->data()->Get(dIdx));
3192*89c4ff92SAndroid Build Coastguard Worker }
3193*89c4ff92SAndroid Build Coastguard Worker }
3194*89c4ff92SAndroid Build Coastguard Worker
3195*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
3196*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddSplitterLayer(viewsDescriptor, layerName.c_str());
3197*89c4ff92SAndroid Build Coastguard Worker
3198*89c4ff92SAndroid Build Coastguard Worker // I could have as many outputs as views ...
3199*89c4ff92SAndroid Build Coastguard Worker for(unsigned int vIdx = 0; vIdx < numViews; ++vIdx)
3200*89c4ff92SAndroid Build Coastguard Worker {
3201*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[vIdx]);
3202*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(vIdx).SetTensorInfo(outputTensorInfo);
3203*89c4ff92SAndroid Build Coastguard Worker }
3204*89c4ff92SAndroid Build Coastguard Worker
3205*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
3206*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
3207*89c4ff92SAndroid Build Coastguard Worker }
3208*89c4ff92SAndroid Build Coastguard Worker
GetLstmDescriptor(LstmDescriptorPtr lstmDescriptor)3209*89c4ff92SAndroid Build Coastguard Worker armnn::LstmDescriptor IDeserializer::DeserializerImpl::GetLstmDescriptor(LstmDescriptorPtr lstmDescriptor)
3210*89c4ff92SAndroid Build Coastguard Worker {
3211*89c4ff92SAndroid Build Coastguard Worker armnn::LstmDescriptor desc;
3212*89c4ff92SAndroid Build Coastguard Worker
3213*89c4ff92SAndroid Build Coastguard Worker desc.m_ActivationFunc = lstmDescriptor->activationFunc();
3214*89c4ff92SAndroid Build Coastguard Worker desc.m_ClippingThresCell = lstmDescriptor->clippingThresCell();
3215*89c4ff92SAndroid Build Coastguard Worker desc.m_ClippingThresProj = lstmDescriptor->clippingThresProj();
3216*89c4ff92SAndroid Build Coastguard Worker desc.m_CifgEnabled = lstmDescriptor->cifgEnabled();
3217*89c4ff92SAndroid Build Coastguard Worker desc.m_PeepholeEnabled = lstmDescriptor->peepholeEnabled();
3218*89c4ff92SAndroid Build Coastguard Worker desc.m_ProjectionEnabled = lstmDescriptor->projectionEnabled();
3219*89c4ff92SAndroid Build Coastguard Worker desc.m_LayerNormEnabled = lstmDescriptor->layerNormEnabled();
3220*89c4ff92SAndroid Build Coastguard Worker
3221*89c4ff92SAndroid Build Coastguard Worker return desc;
3222*89c4ff92SAndroid Build Coastguard Worker }
3223*89c4ff92SAndroid Build Coastguard Worker
ParseLstm(GraphPtr graph,unsigned int layerIndex)3224*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseLstm(GraphPtr graph, unsigned int layerIndex)
3225*89c4ff92SAndroid Build Coastguard Worker {
3226*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
3227*89c4ff92SAndroid Build Coastguard Worker
3228*89c4ff92SAndroid Build Coastguard Worker auto inputs = GetInputs(graph, layerIndex);
3229*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 3);
3230*89c4ff92SAndroid Build Coastguard Worker
3231*89c4ff92SAndroid Build Coastguard Worker auto outputs = GetOutputs(graph, layerIndex);
3232*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 4);
3233*89c4ff92SAndroid Build Coastguard Worker
3234*89c4ff92SAndroid Build Coastguard Worker auto flatBufferLayer = graph->layers()->Get(layerIndex)->layer_as_LstmLayer();
3235*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
3236*89c4ff92SAndroid Build Coastguard Worker auto flatBufferDescriptor = flatBufferLayer->descriptor();
3237*89c4ff92SAndroid Build Coastguard Worker auto flatBufferInputParams = flatBufferLayer->inputParams();
3238*89c4ff92SAndroid Build Coastguard Worker
3239*89c4ff92SAndroid Build Coastguard Worker auto lstmDescriptor = GetLstmDescriptor(flatBufferDescriptor);
3240*89c4ff92SAndroid Build Coastguard Worker
3241*89c4ff92SAndroid Build Coastguard Worker armnn::LstmInputParams lstmInputParams;
3242*89c4ff92SAndroid Build Coastguard Worker
3243*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor inputToForgetWeights = ToConstTensor(flatBufferInputParams->inputToForgetWeights());
3244*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor inputToCellWeights = ToConstTensor(flatBufferInputParams->inputToCellWeights());
3245*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor inputToOutputWeights = ToConstTensor(flatBufferInputParams->inputToOutputWeights());
3246*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor recurrentToForgetWeights = ToConstTensor(flatBufferInputParams->recurrentToForgetWeights());
3247*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor recurrentToCellWeights = ToConstTensor(flatBufferInputParams->recurrentToCellWeights());
3248*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor recurrentToOutputWeights = ToConstTensor(flatBufferInputParams->recurrentToOutputWeights());
3249*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor forgetGateBias = ToConstTensor(flatBufferInputParams->forgetGateBias());
3250*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor cellBias = ToConstTensor(flatBufferInputParams->cellBias());
3251*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor outputGateBias = ToConstTensor(flatBufferInputParams->outputGateBias());
3252*89c4ff92SAndroid Build Coastguard Worker
3253*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_InputToForgetWeights = &inputToForgetWeights;
3254*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_InputToCellWeights = &inputToCellWeights;
3255*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_InputToOutputWeights = &inputToOutputWeights;
3256*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
3257*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_RecurrentToCellWeights = &recurrentToCellWeights;
3258*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
3259*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_ForgetGateBias = &forgetGateBias;
3260*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_CellBias = &cellBias;
3261*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_OutputGateBias = &outputGateBias;
3262*89c4ff92SAndroid Build Coastguard Worker
3263*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor inputToInputWeights;
3264*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor recurrentToInputWeights;
3265*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor cellToInputWeights;
3266*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor inputGateBias;
3267*89c4ff92SAndroid Build Coastguard Worker if (!lstmDescriptor.m_CifgEnabled)
3268*89c4ff92SAndroid Build Coastguard Worker {
3269*89c4ff92SAndroid Build Coastguard Worker inputToInputWeights = ToConstTensor(flatBufferInputParams->inputToInputWeights());
3270*89c4ff92SAndroid Build Coastguard Worker recurrentToInputWeights = ToConstTensor(flatBufferInputParams->recurrentToInputWeights());
3271*89c4ff92SAndroid Build Coastguard Worker cellToInputWeights = ToConstTensor(flatBufferInputParams->cellToInputWeights());
3272*89c4ff92SAndroid Build Coastguard Worker inputGateBias = ToConstTensor(flatBufferInputParams->inputGateBias());
3273*89c4ff92SAndroid Build Coastguard Worker
3274*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_InputToInputWeights = &inputToInputWeights;
3275*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_RecurrentToInputWeights = &recurrentToInputWeights;
3276*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_CellToInputWeights = &cellToInputWeights;
3277*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_InputGateBias = &inputGateBias;
3278*89c4ff92SAndroid Build Coastguard Worker }
3279*89c4ff92SAndroid Build Coastguard Worker
3280*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor projectionWeights;
3281*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor projectionBias;
3282*89c4ff92SAndroid Build Coastguard Worker if (lstmDescriptor.m_ProjectionEnabled)
3283*89c4ff92SAndroid Build Coastguard Worker {
3284*89c4ff92SAndroid Build Coastguard Worker projectionWeights = ToConstTensor(flatBufferInputParams->projectionWeights());
3285*89c4ff92SAndroid Build Coastguard Worker projectionBias = ToConstTensor(flatBufferInputParams->projectionBias());
3286*89c4ff92SAndroid Build Coastguard Worker
3287*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_ProjectionWeights = &projectionWeights;
3288*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_ProjectionBias = &projectionBias;
3289*89c4ff92SAndroid Build Coastguard Worker }
3290*89c4ff92SAndroid Build Coastguard Worker
3291*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor cellToForgetWeights;
3292*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor cellToOutputWeights;
3293*89c4ff92SAndroid Build Coastguard Worker if (lstmDescriptor.m_PeepholeEnabled)
3294*89c4ff92SAndroid Build Coastguard Worker {
3295*89c4ff92SAndroid Build Coastguard Worker cellToForgetWeights = ToConstTensor(flatBufferInputParams->cellToForgetWeights());
3296*89c4ff92SAndroid Build Coastguard Worker cellToOutputWeights = ToConstTensor(flatBufferInputParams->cellToOutputWeights());
3297*89c4ff92SAndroid Build Coastguard Worker
3298*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_CellToForgetWeights = &cellToForgetWeights;
3299*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_CellToOutputWeights = &cellToOutputWeights;
3300*89c4ff92SAndroid Build Coastguard Worker }
3301*89c4ff92SAndroid Build Coastguard Worker
3302*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor inputLayerNormWeights;
3303*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor forgetLayerNormWeights;
3304*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor cellLayerNormWeights;
3305*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor outputLayerNormWeights;
3306*89c4ff92SAndroid Build Coastguard Worker if (lstmDescriptor.m_LayerNormEnabled)
3307*89c4ff92SAndroid Build Coastguard Worker {
3308*89c4ff92SAndroid Build Coastguard Worker if (!lstmDescriptor.m_CifgEnabled)
3309*89c4ff92SAndroid Build Coastguard Worker {
3310*89c4ff92SAndroid Build Coastguard Worker inputLayerNormWeights = ToConstTensor(flatBufferInputParams->inputLayerNormWeights());
3311*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_InputLayerNormWeights = &inputLayerNormWeights;
3312*89c4ff92SAndroid Build Coastguard Worker }
3313*89c4ff92SAndroid Build Coastguard Worker forgetLayerNormWeights = ToConstTensor(flatBufferInputParams->forgetLayerNormWeights());
3314*89c4ff92SAndroid Build Coastguard Worker cellLayerNormWeights = ToConstTensor(flatBufferInputParams->cellLayerNormWeights());
3315*89c4ff92SAndroid Build Coastguard Worker outputLayerNormWeights = ToConstTensor(flatBufferInputParams->outputLayerNormWeights());
3316*89c4ff92SAndroid Build Coastguard Worker
3317*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_ForgetLayerNormWeights = &forgetLayerNormWeights;
3318*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_CellLayerNormWeights = &cellLayerNormWeights;
3319*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_OutputLayerNormWeights = &outputLayerNormWeights;
3320*89c4ff92SAndroid Build Coastguard Worker }
3321*89c4ff92SAndroid Build Coastguard Worker
3322*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddLstmLayer(lstmDescriptor, lstmInputParams, layerName.c_str());
3323*89c4ff92SAndroid Build Coastguard Worker
3324*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo1 = ToTensorInfo(outputs[0]);
3325*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo1);
3326*89c4ff92SAndroid Build Coastguard Worker
3327*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo2 = ToTensorInfo(outputs[1]);
3328*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(1).SetTensorInfo(outputTensorInfo2);
3329*89c4ff92SAndroid Build Coastguard Worker
3330*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo3 = ToTensorInfo(outputs[2]);
3331*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(2).SetTensorInfo(outputTensorInfo3);
3332*89c4ff92SAndroid Build Coastguard Worker
3333*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo4 = ToTensorInfo(outputs[3]);
3334*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(3).SetTensorInfo(outputTensorInfo4);
3335*89c4ff92SAndroid Build Coastguard Worker
3336*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
3337*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
3338*89c4ff92SAndroid Build Coastguard Worker }
3339*89c4ff92SAndroid Build Coastguard Worker
GetQLstmDescriptor(QLstmDescriptorPtr qLstmDescriptor)3340*89c4ff92SAndroid Build Coastguard Worker armnn::QLstmDescriptor IDeserializer::DeserializerImpl::GetQLstmDescriptor(QLstmDescriptorPtr qLstmDescriptor)
3341*89c4ff92SAndroid Build Coastguard Worker {
3342*89c4ff92SAndroid Build Coastguard Worker armnn::QLstmDescriptor desc;
3343*89c4ff92SAndroid Build Coastguard Worker
3344*89c4ff92SAndroid Build Coastguard Worker desc.m_CifgEnabled = qLstmDescriptor->cifgEnabled();
3345*89c4ff92SAndroid Build Coastguard Worker desc.m_PeepholeEnabled = qLstmDescriptor->peepholeEnabled();
3346*89c4ff92SAndroid Build Coastguard Worker desc.m_ProjectionEnabled = qLstmDescriptor->projectionEnabled();
3347*89c4ff92SAndroid Build Coastguard Worker desc.m_LayerNormEnabled = qLstmDescriptor->layerNormEnabled();
3348*89c4ff92SAndroid Build Coastguard Worker
3349*89c4ff92SAndroid Build Coastguard Worker desc.m_CellClip = qLstmDescriptor->cellClip();
3350*89c4ff92SAndroid Build Coastguard Worker desc.m_ProjectionClip = qLstmDescriptor->projectionClip();
3351*89c4ff92SAndroid Build Coastguard Worker
3352*89c4ff92SAndroid Build Coastguard Worker desc.m_InputIntermediateScale = qLstmDescriptor->inputIntermediateScale();
3353*89c4ff92SAndroid Build Coastguard Worker desc.m_ForgetIntermediateScale = qLstmDescriptor->forgetIntermediateScale();
3354*89c4ff92SAndroid Build Coastguard Worker desc.m_CellIntermediateScale = qLstmDescriptor->cellIntermediateScale();
3355*89c4ff92SAndroid Build Coastguard Worker desc.m_OutputIntermediateScale = qLstmDescriptor->outputIntermediateScale();
3356*89c4ff92SAndroid Build Coastguard Worker
3357*89c4ff92SAndroid Build Coastguard Worker desc.m_HiddenStateScale = qLstmDescriptor->hiddenStateScale();
3358*89c4ff92SAndroid Build Coastguard Worker desc.m_HiddenStateZeroPoint = qLstmDescriptor->hiddenStateZeroPoint();
3359*89c4ff92SAndroid Build Coastguard Worker
3360*89c4ff92SAndroid Build Coastguard Worker return desc;
3361*89c4ff92SAndroid Build Coastguard Worker }
3362*89c4ff92SAndroid Build Coastguard Worker
ParseQLstm(GraphPtr graph,unsigned int layerIndex)3363*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseQLstm(GraphPtr graph, unsigned int layerIndex)
3364*89c4ff92SAndroid Build Coastguard Worker {
3365*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
3366*89c4ff92SAndroid Build Coastguard Worker
3367*89c4ff92SAndroid Build Coastguard Worker auto inputs = GetInputs(graph, layerIndex);
3368*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 3);
3369*89c4ff92SAndroid Build Coastguard Worker
3370*89c4ff92SAndroid Build Coastguard Worker auto outputs = GetOutputs(graph, layerIndex);
3371*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 3);
3372*89c4ff92SAndroid Build Coastguard Worker
3373*89c4ff92SAndroid Build Coastguard Worker auto flatBufferLayer = graph->layers()->Get(layerIndex)->layer_as_QLstmLayer();
3374*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
3375*89c4ff92SAndroid Build Coastguard Worker auto flatBufferDescriptor = flatBufferLayer->descriptor();
3376*89c4ff92SAndroid Build Coastguard Worker auto flatBufferInputParams = flatBufferLayer->inputParams();
3377*89c4ff92SAndroid Build Coastguard Worker
3378*89c4ff92SAndroid Build Coastguard Worker auto qLstmDescriptor = GetQLstmDescriptor(flatBufferDescriptor);
3379*89c4ff92SAndroid Build Coastguard Worker armnn::LstmInputParams qLstmInputParams;
3380*89c4ff92SAndroid Build Coastguard Worker
3381*89c4ff92SAndroid Build Coastguard Worker // Mandatory params
3382*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor inputToForgetWeights = ToConstTensor(flatBufferInputParams->inputToForgetWeights());
3383*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor inputToCellWeights = ToConstTensor(flatBufferInputParams->inputToCellWeights());
3384*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor inputToOutputWeights = ToConstTensor(flatBufferInputParams->inputToOutputWeights());
3385*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor recurrentToForgetWeights = ToConstTensor(flatBufferInputParams->recurrentToForgetWeights());
3386*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor recurrentToCellWeights = ToConstTensor(flatBufferInputParams->recurrentToCellWeights());
3387*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor recurrentToOutputWeights = ToConstTensor(flatBufferInputParams->recurrentToOutputWeights());
3388*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor forgetGateBias = ToConstTensor(flatBufferInputParams->forgetGateBias());
3389*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor cellBias = ToConstTensor(flatBufferInputParams->cellBias());
3390*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor outputGateBias = ToConstTensor(flatBufferInputParams->outputGateBias());
3391*89c4ff92SAndroid Build Coastguard Worker
3392*89c4ff92SAndroid Build Coastguard Worker qLstmInputParams.m_InputToForgetWeights = &inputToForgetWeights;
3393*89c4ff92SAndroid Build Coastguard Worker qLstmInputParams.m_InputToCellWeights = &inputToCellWeights;
3394*89c4ff92SAndroid Build Coastguard Worker qLstmInputParams.m_InputToOutputWeights = &inputToOutputWeights;
3395*89c4ff92SAndroid Build Coastguard Worker qLstmInputParams.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
3396*89c4ff92SAndroid Build Coastguard Worker qLstmInputParams.m_RecurrentToCellWeights = &recurrentToCellWeights;
3397*89c4ff92SAndroid Build Coastguard Worker qLstmInputParams.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
3398*89c4ff92SAndroid Build Coastguard Worker qLstmInputParams.m_ForgetGateBias = &forgetGateBias;
3399*89c4ff92SAndroid Build Coastguard Worker qLstmInputParams.m_CellBias = &cellBias;
3400*89c4ff92SAndroid Build Coastguard Worker qLstmInputParams.m_OutputGateBias = &outputGateBias;
3401*89c4ff92SAndroid Build Coastguard Worker
3402*89c4ff92SAndroid Build Coastguard Worker // Optional CIFG params
3403*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor inputToInputWeights;
3404*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor recurrentToInputWeights;
3405*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor inputGateBias;
3406*89c4ff92SAndroid Build Coastguard Worker
3407*89c4ff92SAndroid Build Coastguard Worker if (!qLstmDescriptor.m_CifgEnabled)
3408*89c4ff92SAndroid Build Coastguard Worker {
3409*89c4ff92SAndroid Build Coastguard Worker inputToInputWeights = ToConstTensor(flatBufferInputParams->inputToInputWeights());
3410*89c4ff92SAndroid Build Coastguard Worker recurrentToInputWeights = ToConstTensor(flatBufferInputParams->recurrentToInputWeights());
3411*89c4ff92SAndroid Build Coastguard Worker inputGateBias = ToConstTensor(flatBufferInputParams->inputGateBias());
3412*89c4ff92SAndroid Build Coastguard Worker
3413*89c4ff92SAndroid Build Coastguard Worker qLstmInputParams.m_InputToInputWeights = &inputToInputWeights;
3414*89c4ff92SAndroid Build Coastguard Worker qLstmInputParams.m_RecurrentToInputWeights = &recurrentToInputWeights;
3415*89c4ff92SAndroid Build Coastguard Worker qLstmInputParams.m_InputGateBias = &inputGateBias;
3416*89c4ff92SAndroid Build Coastguard Worker }
3417*89c4ff92SAndroid Build Coastguard Worker
3418*89c4ff92SAndroid Build Coastguard Worker // Optional projection params
3419*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor projectionWeights;
3420*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor projectionBias;
3421*89c4ff92SAndroid Build Coastguard Worker
3422*89c4ff92SAndroid Build Coastguard Worker if (qLstmDescriptor.m_ProjectionEnabled)
3423*89c4ff92SAndroid Build Coastguard Worker {
3424*89c4ff92SAndroid Build Coastguard Worker projectionWeights = ToConstTensor(flatBufferInputParams->projectionWeights());
3425*89c4ff92SAndroid Build Coastguard Worker projectionBias = ToConstTensor(flatBufferInputParams->projectionBias());
3426*89c4ff92SAndroid Build Coastguard Worker
3427*89c4ff92SAndroid Build Coastguard Worker qLstmInputParams.m_ProjectionWeights = &projectionWeights;
3428*89c4ff92SAndroid Build Coastguard Worker qLstmInputParams.m_ProjectionBias = &projectionBias;
3429*89c4ff92SAndroid Build Coastguard Worker }
3430*89c4ff92SAndroid Build Coastguard Worker
3431*89c4ff92SAndroid Build Coastguard Worker // Optional peephole params
3432*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor cellToInputWeights;
3433*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor cellToForgetWeights;
3434*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor cellToOutputWeights;
3435*89c4ff92SAndroid Build Coastguard Worker
3436*89c4ff92SAndroid Build Coastguard Worker if (qLstmDescriptor.m_PeepholeEnabled)
3437*89c4ff92SAndroid Build Coastguard Worker {
3438*89c4ff92SAndroid Build Coastguard Worker if (!qLstmDescriptor.m_CifgEnabled)
3439*89c4ff92SAndroid Build Coastguard Worker {
3440*89c4ff92SAndroid Build Coastguard Worker cellToInputWeights = ToConstTensor(flatBufferInputParams->cellToInputWeights());
3441*89c4ff92SAndroid Build Coastguard Worker qLstmInputParams.m_CellToInputWeights = &cellToInputWeights;
3442*89c4ff92SAndroid Build Coastguard Worker }
3443*89c4ff92SAndroid Build Coastguard Worker
3444*89c4ff92SAndroid Build Coastguard Worker cellToForgetWeights = ToConstTensor(flatBufferInputParams->cellToForgetWeights());
3445*89c4ff92SAndroid Build Coastguard Worker cellToOutputWeights = ToConstTensor(flatBufferInputParams->cellToOutputWeights());
3446*89c4ff92SAndroid Build Coastguard Worker
3447*89c4ff92SAndroid Build Coastguard Worker qLstmInputParams.m_CellToForgetWeights = &cellToForgetWeights;
3448*89c4ff92SAndroid Build Coastguard Worker qLstmInputParams.m_CellToOutputWeights = &cellToOutputWeights;
3449*89c4ff92SAndroid Build Coastguard Worker }
3450*89c4ff92SAndroid Build Coastguard Worker
3451*89c4ff92SAndroid Build Coastguard Worker // Optional layer norm params
3452*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor inputLayerNormWeights;
3453*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor forgetLayerNormWeights;
3454*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor cellLayerNormWeights;
3455*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor outputLayerNormWeights;
3456*89c4ff92SAndroid Build Coastguard Worker
3457*89c4ff92SAndroid Build Coastguard Worker if (qLstmDescriptor.m_LayerNormEnabled)
3458*89c4ff92SAndroid Build Coastguard Worker {
3459*89c4ff92SAndroid Build Coastguard Worker if (!qLstmDescriptor.m_CifgEnabled)
3460*89c4ff92SAndroid Build Coastguard Worker {
3461*89c4ff92SAndroid Build Coastguard Worker inputLayerNormWeights = ToConstTensor(flatBufferInputParams->inputLayerNormWeights());
3462*89c4ff92SAndroid Build Coastguard Worker qLstmInputParams.m_InputLayerNormWeights = &inputLayerNormWeights;
3463*89c4ff92SAndroid Build Coastguard Worker }
3464*89c4ff92SAndroid Build Coastguard Worker
3465*89c4ff92SAndroid Build Coastguard Worker forgetLayerNormWeights = ToConstTensor(flatBufferInputParams->forgetLayerNormWeights());
3466*89c4ff92SAndroid Build Coastguard Worker cellLayerNormWeights = ToConstTensor(flatBufferInputParams->cellLayerNormWeights());
3467*89c4ff92SAndroid Build Coastguard Worker outputLayerNormWeights = ToConstTensor(flatBufferInputParams->outputLayerNormWeights());
3468*89c4ff92SAndroid Build Coastguard Worker
3469*89c4ff92SAndroid Build Coastguard Worker qLstmInputParams.m_ForgetLayerNormWeights = &forgetLayerNormWeights;
3470*89c4ff92SAndroid Build Coastguard Worker qLstmInputParams.m_CellLayerNormWeights = &cellLayerNormWeights;
3471*89c4ff92SAndroid Build Coastguard Worker qLstmInputParams.m_OutputLayerNormWeights = &outputLayerNormWeights;
3472*89c4ff92SAndroid Build Coastguard Worker }
3473*89c4ff92SAndroid Build Coastguard Worker
3474*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddQLstmLayer(qLstmDescriptor, qLstmInputParams, layerName.c_str());
3475*89c4ff92SAndroid Build Coastguard Worker
3476*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputStateOutInfo = ToTensorInfo(outputs[0]);
3477*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputStateOutInfo);
3478*89c4ff92SAndroid Build Coastguard Worker
3479*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo cellStateOutInfo = ToTensorInfo(outputs[1]);
3480*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(1).SetTensorInfo(cellStateOutInfo);
3481*89c4ff92SAndroid Build Coastguard Worker
3482*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputInfo = ToTensorInfo(outputs[2]);
3483*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(2).SetTensorInfo(outputInfo);
3484*89c4ff92SAndroid Build Coastguard Worker
3485*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
3486*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
3487*89c4ff92SAndroid Build Coastguard Worker }
3488*89c4ff92SAndroid Build Coastguard Worker
ParseQuantizedLstm(GraphPtr graph,unsigned int layerIndex)3489*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseQuantizedLstm(GraphPtr graph, unsigned int layerIndex)
3490*89c4ff92SAndroid Build Coastguard Worker {
3491*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
3492*89c4ff92SAndroid Build Coastguard Worker
3493*89c4ff92SAndroid Build Coastguard Worker auto inputs = GetInputs(graph, layerIndex);
3494*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 3);
3495*89c4ff92SAndroid Build Coastguard Worker
3496*89c4ff92SAndroid Build Coastguard Worker auto outputs = GetOutputs(graph, layerIndex);
3497*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 2);
3498*89c4ff92SAndroid Build Coastguard Worker
3499*89c4ff92SAndroid Build Coastguard Worker auto flatBufferLayer = graph->layers()->Get(layerIndex)->layer_as_QuantizedLstmLayer();
3500*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
3501*89c4ff92SAndroid Build Coastguard Worker auto flatBufferInputParams = flatBufferLayer->inputParams();
3502*89c4ff92SAndroid Build Coastguard Worker
3503*89c4ff92SAndroid Build Coastguard Worker armnn::QuantizedLstmInputParams lstmInputParams;
3504*89c4ff92SAndroid Build Coastguard Worker
3505*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor inputToInputWeights = ToConstTensor(flatBufferInputParams->inputToInputWeights());
3506*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor inputToForgetWeights = ToConstTensor(flatBufferInputParams->inputToForgetWeights());
3507*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor inputToCellWeights = ToConstTensor(flatBufferInputParams->inputToCellWeights());
3508*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor inputToOutputWeights = ToConstTensor(flatBufferInputParams->inputToOutputWeights());
3509*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor recurrentToInputWeights = ToConstTensor(flatBufferInputParams->recurrentToInputWeights());
3510*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor recurrentToForgetWeights = ToConstTensor(flatBufferInputParams->recurrentToForgetWeights());
3511*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor recurrentToCellWeights = ToConstTensor(flatBufferInputParams->recurrentToCellWeights());
3512*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor recurrentToOutputWeights = ToConstTensor(flatBufferInputParams->recurrentToOutputWeights());
3513*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor inputGateBias = ToConstTensor(flatBufferInputParams->inputGateBias());
3514*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor forgetGateBias = ToConstTensor(flatBufferInputParams->forgetGateBias());
3515*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor cellBias = ToConstTensor(flatBufferInputParams->cellBias());
3516*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor outputGateBias = ToConstTensor(flatBufferInputParams->outputGateBias());
3517*89c4ff92SAndroid Build Coastguard Worker
3518*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_InputToInputWeights = &inputToInputWeights;
3519*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_InputToForgetWeights = &inputToForgetWeights;
3520*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_InputToCellWeights = &inputToCellWeights;
3521*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_InputToOutputWeights = &inputToOutputWeights;
3522*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_RecurrentToInputWeights = &recurrentToInputWeights;
3523*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
3524*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_RecurrentToCellWeights = &recurrentToCellWeights;
3525*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
3526*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_InputGateBias = &inputGateBias;
3527*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_ForgetGateBias = &forgetGateBias;
3528*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_CellBias = &cellBias;
3529*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_OutputGateBias = &outputGateBias;
3530*89c4ff92SAndroid Build Coastguard Worker
3531*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddQuantizedLstmLayer(lstmInputParams, layerName.c_str());
3532*89c4ff92SAndroid Build Coastguard Worker
3533*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo1 = ToTensorInfo(outputs[0]);
3534*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo1);
3535*89c4ff92SAndroid Build Coastguard Worker
3536*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo2 = ToTensorInfo(outputs[1]);
3537*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(1).SetTensorInfo(outputTensorInfo2);
3538*89c4ff92SAndroid Build Coastguard Worker
3539*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
3540*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
3541*89c4ff92SAndroid Build Coastguard Worker }
3542*89c4ff92SAndroid Build Coastguard Worker
ParseDequantize(GraphPtr graph,unsigned int layerIndex)3543*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseDequantize(GraphPtr graph, unsigned int layerIndex)
3544*89c4ff92SAndroid Build Coastguard Worker {
3545*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
3546*89c4ff92SAndroid Build Coastguard Worker
3547*89c4ff92SAndroid Build Coastguard Worker TensorRawPtrVector inputs = GetInputs(graph, layerIndex);
3548*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 1);
3549*89c4ff92SAndroid Build Coastguard Worker
3550*89c4ff92SAndroid Build Coastguard Worker TensorRawPtrVector outputs = GetOutputs(graph, layerIndex);
3551*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
3552*89c4ff92SAndroid Build Coastguard Worker
3553*89c4ff92SAndroid Build Coastguard Worker const std::string layerName = GetLayerName(graph, layerIndex);
3554*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddDequantizeLayer(layerName.c_str());
3555*89c4ff92SAndroid Build Coastguard Worker
3556*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
3557*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
3558*89c4ff92SAndroid Build Coastguard Worker
3559*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
3560*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
3561*89c4ff92SAndroid Build Coastguard Worker }
3562*89c4ff92SAndroid Build Coastguard Worker
ParseMerge(GraphPtr graph,unsigned int layerIndex)3563*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseMerge(GraphPtr graph, unsigned int layerIndex)
3564*89c4ff92SAndroid Build Coastguard Worker {
3565*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
3566*89c4ff92SAndroid Build Coastguard Worker
3567*89c4ff92SAndroid Build Coastguard Worker TensorRawPtrVector inputs = GetInputs(graph, layerIndex);
3568*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 2);
3569*89c4ff92SAndroid Build Coastguard Worker
3570*89c4ff92SAndroid Build Coastguard Worker TensorRawPtrVector outputs = GetOutputs(graph, layerIndex);
3571*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
3572*89c4ff92SAndroid Build Coastguard Worker
3573*89c4ff92SAndroid Build Coastguard Worker const std::string layerName = GetLayerName(graph, layerIndex);
3574*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddMergeLayer(layerName.c_str());
3575*89c4ff92SAndroid Build Coastguard Worker
3576*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
3577*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
3578*89c4ff92SAndroid Build Coastguard Worker
3579*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
3580*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
3581*89c4ff92SAndroid Build Coastguard Worker }
3582*89c4ff92SAndroid Build Coastguard Worker
ParseSwitch(GraphPtr graph,unsigned int layerIndex)3583*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseSwitch(GraphPtr graph, unsigned int layerIndex)
3584*89c4ff92SAndroid Build Coastguard Worker {
3585*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
3586*89c4ff92SAndroid Build Coastguard Worker auto inputs = GetInputs(graph, layerIndex);
3587*89c4ff92SAndroid Build Coastguard Worker CHECK_LOCATION();
3588*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 2);
3589*89c4ff92SAndroid Build Coastguard Worker
3590*89c4ff92SAndroid Build Coastguard Worker auto outputs = GetOutputs(graph, layerIndex);
3591*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 2);
3592*89c4ff92SAndroid Build Coastguard Worker
3593*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
3594*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddSwitchLayer(layerName.c_str());
3595*89c4ff92SAndroid Build Coastguard Worker
3596*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo output0TensorInfo = ToTensorInfo(outputs[0]);
3597*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(output0TensorInfo);
3598*89c4ff92SAndroid Build Coastguard Worker
3599*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo output1TensorInfo = ToTensorInfo(outputs[1]);
3600*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(1).SetTensorInfo(output1TensorInfo);
3601*89c4ff92SAndroid Build Coastguard Worker
3602*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
3603*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
3604*89c4ff92SAndroid Build Coastguard Worker }
3605*89c4ff92SAndroid Build Coastguard Worker
ParsePrelu(GraphPtr graph,unsigned int layerIndex)3606*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParsePrelu(GraphPtr graph, unsigned int layerIndex)
3607*89c4ff92SAndroid Build Coastguard Worker {
3608*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
3609*89c4ff92SAndroid Build Coastguard Worker auto inputs = GetInputs(graph, layerIndex);
3610*89c4ff92SAndroid Build Coastguard Worker CHECK_LOCATION();
3611*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 2);
3612*89c4ff92SAndroid Build Coastguard Worker
3613*89c4ff92SAndroid Build Coastguard Worker auto outputs = GetOutputs(graph, layerIndex);
3614*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
3615*89c4ff92SAndroid Build Coastguard Worker
3616*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
3617*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddPreluLayer(layerName.c_str());
3618*89c4ff92SAndroid Build Coastguard Worker
3619*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
3620*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
3621*89c4ff92SAndroid Build Coastguard Worker
3622*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
3623*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
3624*89c4ff92SAndroid Build Coastguard Worker }
3625*89c4ff92SAndroid Build Coastguard Worker
ParseTranspose(GraphPtr graph,unsigned int layerIndex)3626*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseTranspose(GraphPtr graph, unsigned int layerIndex)
3627*89c4ff92SAndroid Build Coastguard Worker {
3628*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
3629*89c4ff92SAndroid Build Coastguard Worker
3630*89c4ff92SAndroid Build Coastguard Worker auto dimsMapping = graph->layers()->Get(layerIndex)->layer_as_TransposeLayer()->descriptor()->dimMappings();
3631*89c4ff92SAndroid Build Coastguard Worker
3632*89c4ff92SAndroid Build Coastguard Worker auto inputs = GetInputs(graph, layerIndex);
3633*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 1);
3634*89c4ff92SAndroid Build Coastguard Worker
3635*89c4ff92SAndroid Build Coastguard Worker auto outputs = GetOutputs(graph, layerIndex);
3636*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
3637*89c4ff92SAndroid Build Coastguard Worker auto outputInfo = ToTensorInfo(outputs[0]);
3638*89c4ff92SAndroid Build Coastguard Worker
3639*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
3640*89c4ff92SAndroid Build Coastguard Worker const armnn::TransposeDescriptor descriptor(armnn::PermutationVector(dimsMapping->data(), dimsMapping->size()));
3641*89c4ff92SAndroid Build Coastguard Worker
3642*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddTransposeLayer(descriptor, layerName.c_str());
3643*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputInfo);
3644*89c4ff92SAndroid Build Coastguard Worker
3645*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
3646*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
3647*89c4ff92SAndroid Build Coastguard Worker }
3648*89c4ff92SAndroid Build Coastguard Worker
ParseTransposeConvolution2d(GraphPtr graph,unsigned int layerIndex)3649*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseTransposeConvolution2d(GraphPtr graph, unsigned int layerIndex)
3650*89c4ff92SAndroid Build Coastguard Worker {
3651*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
3652*89c4ff92SAndroid Build Coastguard Worker
3653*89c4ff92SAndroid Build Coastguard Worker auto inputs = GetInputs(graph, layerIndex);
3654*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 1);
3655*89c4ff92SAndroid Build Coastguard Worker
3656*89c4ff92SAndroid Build Coastguard Worker auto outputs = GetOutputs(graph, layerIndex);
3657*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
3658*89c4ff92SAndroid Build Coastguard Worker
3659*89c4ff92SAndroid Build Coastguard Worker auto serializerLayer = graph->layers()->Get(layerIndex)->layer_as_TransposeConvolution2dLayer();
3660*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
3661*89c4ff92SAndroid Build Coastguard Worker auto serializerDescriptor = serializerLayer->descriptor();
3662*89c4ff92SAndroid Build Coastguard Worker
3663*89c4ff92SAndroid Build Coastguard Worker armnn::TransposeConvolution2dDescriptor descriptor;
3664*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadLeft = serializerDescriptor->padLeft();
3665*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadRight = serializerDescriptor->padRight();
3666*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadTop = serializerDescriptor->padTop();
3667*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadBottom = serializerDescriptor->padBottom();
3668*89c4ff92SAndroid Build Coastguard Worker descriptor.m_StrideX = serializerDescriptor->strideX();
3669*89c4ff92SAndroid Build Coastguard Worker descriptor.m_StrideY = serializerDescriptor->strideY();;
3670*89c4ff92SAndroid Build Coastguard Worker descriptor.m_BiasEnabled = serializerDescriptor->biasEnabled();;
3671*89c4ff92SAndroid Build Coastguard Worker descriptor.m_DataLayout = ToDataLayout(serializerDescriptor->dataLayout());
3672*89c4ff92SAndroid Build Coastguard Worker
3673*89c4ff92SAndroid Build Coastguard Worker // weights & biases
3674*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor weights = ToConstTensor(serializerLayer->weights());
3675*89c4ff92SAndroid Build Coastguard Worker armnn::Optional<armnn::ConstTensor> optionalBiases;
3676*89c4ff92SAndroid Build Coastguard Worker if (descriptor.m_BiasEnabled)
3677*89c4ff92SAndroid Build Coastguard Worker {
3678*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor biases = ToConstTensor(serializerLayer->biases());
3679*89c4ff92SAndroid Build Coastguard Worker optionalBiases = armnn::MakeOptional<armnn::ConstTensor>(biases);
3680*89c4ff92SAndroid Build Coastguard Worker }
3681*89c4ff92SAndroid Build Coastguard Worker
3682*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddTransposeConvolution2dLayer(descriptor,
3683*89c4ff92SAndroid Build Coastguard Worker weights,
3684*89c4ff92SAndroid Build Coastguard Worker optionalBiases,
3685*89c4ff92SAndroid Build Coastguard Worker layerName.c_str());
3686*89c4ff92SAndroid Build Coastguard Worker
3687*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
3688*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
3689*89c4ff92SAndroid Build Coastguard Worker
3690*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
3691*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
3692*89c4ff92SAndroid Build Coastguard Worker }
3693*89c4ff92SAndroid Build Coastguard Worker
ParseStack(GraphPtr graph,unsigned int layerIndex)3694*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseStack(GraphPtr graph, unsigned int layerIndex)
3695*89c4ff92SAndroid Build Coastguard Worker {
3696*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
3697*89c4ff92SAndroid Build Coastguard Worker auto inputs = GetInputs(graph, layerIndex);
3698*89c4ff92SAndroid Build Coastguard Worker
3699*89c4ff92SAndroid Build Coastguard Worker auto outputs = GetOutputs(graph, layerIndex);
3700*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 1);
3701*89c4ff92SAndroid Build Coastguard Worker
3702*89c4ff92SAndroid Build Coastguard Worker auto flatBufferDescriptor = graph->layers()->Get(layerIndex)->layer_as_StackLayer()->descriptor();
3703*89c4ff92SAndroid Build Coastguard Worker unsigned int axis = flatBufferDescriptor->axis();
3704*89c4ff92SAndroid Build Coastguard Worker unsigned int numInputs = flatBufferDescriptor->numInputs();
3705*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), numInputs);
3706*89c4ff92SAndroid Build Coastguard Worker
3707*89c4ff92SAndroid Build Coastguard Worker auto flatBufferInputShape = flatBufferDescriptor->inputShape();
3708*89c4ff92SAndroid Build Coastguard Worker std::vector<uint32_t> vectorInputShape(flatBufferInputShape->begin(),
3709*89c4ff92SAndroid Build Coastguard Worker flatBufferInputShape->begin() + flatBufferInputShape->size());
3710*89c4ff92SAndroid Build Coastguard Worker
3711*89c4ff92SAndroid Build Coastguard Worker TensorShape inputShape(static_cast<unsigned int>(vectorInputShape.size()), vectorInputShape.data());
3712*89c4ff92SAndroid Build Coastguard Worker armnn::StackDescriptor descriptor(axis, numInputs, inputShape);
3713*89c4ff92SAndroid Build Coastguard Worker
3714*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i=0; i<inputs.size(); ++i)
3715*89c4ff92SAndroid Build Coastguard Worker {
3716*89c4ff92SAndroid Build Coastguard Worker armnn::TensorShape inputShape = ToTensorInfo(inputs[i]).GetShape();
3717*89c4ff92SAndroid Build Coastguard Worker if (descriptor.m_InputShape != inputShape)
3718*89c4ff92SAndroid Build Coastguard Worker {
3719*89c4ff92SAndroid Build Coastguard Worker std::stringstream ss;
3720*89c4ff92SAndroid Build Coastguard Worker ss << "Shape of input "
3721*89c4ff92SAndroid Build Coastguard Worker << i
3722*89c4ff92SAndroid Build Coastguard Worker << " "
3723*89c4ff92SAndroid Build Coastguard Worker << inputShape
3724*89c4ff92SAndroid Build Coastguard Worker << " does not equal defined input shape "
3725*89c4ff92SAndroid Build Coastguard Worker << descriptor.m_InputShape
3726*89c4ff92SAndroid Build Coastguard Worker << ": "
3727*89c4ff92SAndroid Build Coastguard Worker << CHECK_LOCATION().AsString();
3728*89c4ff92SAndroid Build Coastguard Worker throw ParseException(ss.str());
3729*89c4ff92SAndroid Build Coastguard Worker }
3730*89c4ff92SAndroid Build Coastguard Worker }
3731*89c4ff92SAndroid Build Coastguard Worker
3732*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
3733*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddStackLayer(descriptor, layerName.c_str());
3734*89c4ff92SAndroid Build Coastguard Worker
3735*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
3736*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
3737*89c4ff92SAndroid Build Coastguard Worker
3738*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
3739*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
3740*89c4ff92SAndroid Build Coastguard Worker }
3741*89c4ff92SAndroid Build Coastguard Worker
ParseStandIn(GraphPtr graph,unsigned int layerIndex)3742*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseStandIn(GraphPtr graph, unsigned int layerIndex)
3743*89c4ff92SAndroid Build Coastguard Worker {
3744*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
3745*89c4ff92SAndroid Build Coastguard Worker
3746*89c4ff92SAndroid Build Coastguard Worker auto inputs = GetInputs(graph, layerIndex);
3747*89c4ff92SAndroid Build Coastguard Worker auto outputs = GetOutputs(graph, layerIndex);
3748*89c4ff92SAndroid Build Coastguard Worker
3749*89c4ff92SAndroid Build Coastguard Worker auto fbLayer = graph->layers()->Get(layerIndex)->layer_as_StandInLayer();
3750*89c4ff92SAndroid Build Coastguard Worker auto fbDescriptor = fbLayer->descriptor();
3751*89c4ff92SAndroid Build Coastguard Worker
3752*89c4ff92SAndroid Build Coastguard Worker armnn::StandInDescriptor descriptor;
3753*89c4ff92SAndroid Build Coastguard Worker descriptor.m_NumInputs = fbDescriptor->numInputs();
3754*89c4ff92SAndroid Build Coastguard Worker descriptor.m_NumOutputs = fbDescriptor->numOutputs();
3755*89c4ff92SAndroid Build Coastguard Worker
3756*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), descriptor.m_NumInputs);
3757*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), descriptor.m_NumOutputs);
3758*89c4ff92SAndroid Build Coastguard Worker
3759*89c4ff92SAndroid Build Coastguard Worker const std::string layerName = GetLayerName(graph, layerIndex);
3760*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* layer = m_Network->AddStandInLayer(descriptor, layerName.c_str());
3761*89c4ff92SAndroid Build Coastguard Worker
3762*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0u; i < descriptor.m_NumOutputs; ++i)
3763*89c4ff92SAndroid Build Coastguard Worker {
3764*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputInfo = ToTensorInfo(outputs[i]);
3765*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(i).SetTensorInfo(outputInfo);
3766*89c4ff92SAndroid Build Coastguard Worker }
3767*89c4ff92SAndroid Build Coastguard Worker
3768*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
3769*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
3770*89c4ff92SAndroid Build Coastguard Worker }
3771*89c4ff92SAndroid Build Coastguard Worker
GetUnidirectionalSequenceLstmDescriptor(UnidirectionalSequenceLstmDescriptorPtr descriptor)3772*89c4ff92SAndroid Build Coastguard Worker armnn::UnidirectionalSequenceLstmDescriptor IDeserializer::DeserializerImpl::GetUnidirectionalSequenceLstmDescriptor(
3773*89c4ff92SAndroid Build Coastguard Worker UnidirectionalSequenceLstmDescriptorPtr descriptor)
3774*89c4ff92SAndroid Build Coastguard Worker {
3775*89c4ff92SAndroid Build Coastguard Worker armnn::UnidirectionalSequenceLstmDescriptor desc;
3776*89c4ff92SAndroid Build Coastguard Worker
3777*89c4ff92SAndroid Build Coastguard Worker desc.m_ActivationFunc = descriptor->activationFunc();
3778*89c4ff92SAndroid Build Coastguard Worker desc.m_ClippingThresCell = descriptor->clippingThresCell();
3779*89c4ff92SAndroid Build Coastguard Worker desc.m_ClippingThresProj = descriptor->clippingThresProj();
3780*89c4ff92SAndroid Build Coastguard Worker desc.m_CifgEnabled = descriptor->cifgEnabled();
3781*89c4ff92SAndroid Build Coastguard Worker desc.m_PeepholeEnabled = descriptor->peepholeEnabled();
3782*89c4ff92SAndroid Build Coastguard Worker desc.m_ProjectionEnabled = descriptor->projectionEnabled();
3783*89c4ff92SAndroid Build Coastguard Worker desc.m_LayerNormEnabled = descriptor->layerNormEnabled();
3784*89c4ff92SAndroid Build Coastguard Worker desc.m_TimeMajor = descriptor->timeMajor();
3785*89c4ff92SAndroid Build Coastguard Worker
3786*89c4ff92SAndroid Build Coastguard Worker return desc;
3787*89c4ff92SAndroid Build Coastguard Worker }
3788*89c4ff92SAndroid Build Coastguard Worker
ParseUnidirectionalSequenceLstm(GraphPtr graph,unsigned int layerIndex)3789*89c4ff92SAndroid Build Coastguard Worker void IDeserializer::DeserializerImpl::ParseUnidirectionalSequenceLstm(GraphPtr graph, unsigned int layerIndex)
3790*89c4ff92SAndroid Build Coastguard Worker {
3791*89c4ff92SAndroid Build Coastguard Worker CHECK_LAYERS(graph, 0, layerIndex);
3792*89c4ff92SAndroid Build Coastguard Worker
3793*89c4ff92SAndroid Build Coastguard Worker auto inputs = GetInputs(graph, layerIndex);
3794*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(inputs.size(), 3);
3795*89c4ff92SAndroid Build Coastguard Worker
3796*89c4ff92SAndroid Build Coastguard Worker auto outputs = GetOutputs(graph, layerIndex);
3797*89c4ff92SAndroid Build Coastguard Worker CHECK_VALID_SIZE(outputs.size(), 3);
3798*89c4ff92SAndroid Build Coastguard Worker
3799*89c4ff92SAndroid Build Coastguard Worker auto flatBufferLayer = graph->layers()->Get(layerIndex)->layer_as_UnidirectionalSequenceLstmLayer();
3800*89c4ff92SAndroid Build Coastguard Worker auto layerName = GetLayerName(graph, layerIndex);
3801*89c4ff92SAndroid Build Coastguard Worker auto flatBufferDescriptor = flatBufferLayer->descriptor();
3802*89c4ff92SAndroid Build Coastguard Worker auto flatBufferInputParams = flatBufferLayer->inputParams();
3803*89c4ff92SAndroid Build Coastguard Worker
3804*89c4ff92SAndroid Build Coastguard Worker auto descriptor = GetUnidirectionalSequenceLstmDescriptor(flatBufferDescriptor);
3805*89c4ff92SAndroid Build Coastguard Worker
3806*89c4ff92SAndroid Build Coastguard Worker armnn::LstmInputParams lstmInputParams;
3807*89c4ff92SAndroid Build Coastguard Worker
3808*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor inputToForgetWeights = ToConstTensor(flatBufferInputParams->inputToForgetWeights());
3809*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor inputToCellWeights = ToConstTensor(flatBufferInputParams->inputToCellWeights());
3810*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor inputToOutputWeights = ToConstTensor(flatBufferInputParams->inputToOutputWeights());
3811*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor recurrentToForgetWeights = ToConstTensor(flatBufferInputParams->recurrentToForgetWeights());
3812*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor recurrentToCellWeights = ToConstTensor(flatBufferInputParams->recurrentToCellWeights());
3813*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor recurrentToOutputWeights = ToConstTensor(flatBufferInputParams->recurrentToOutputWeights());
3814*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor forgetGateBias = ToConstTensor(flatBufferInputParams->forgetGateBias());
3815*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor cellBias = ToConstTensor(flatBufferInputParams->cellBias());
3816*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor outputGateBias = ToConstTensor(flatBufferInputParams->outputGateBias());
3817*89c4ff92SAndroid Build Coastguard Worker
3818*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_InputToForgetWeights = &inputToForgetWeights;
3819*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_InputToCellWeights = &inputToCellWeights;
3820*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_InputToOutputWeights = &inputToOutputWeights;
3821*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
3822*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_RecurrentToCellWeights = &recurrentToCellWeights;
3823*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
3824*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_ForgetGateBias = &forgetGateBias;
3825*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_CellBias = &cellBias;
3826*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_OutputGateBias = &outputGateBias;
3827*89c4ff92SAndroid Build Coastguard Worker
3828*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor inputToInputWeights;
3829*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor recurrentToInputWeights;
3830*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor cellToInputWeights;
3831*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor inputGateBias;
3832*89c4ff92SAndroid Build Coastguard Worker if (!descriptor.m_CifgEnabled)
3833*89c4ff92SAndroid Build Coastguard Worker {
3834*89c4ff92SAndroid Build Coastguard Worker inputToInputWeights = ToConstTensor(flatBufferInputParams->inputToInputWeights());
3835*89c4ff92SAndroid Build Coastguard Worker recurrentToInputWeights = ToConstTensor(flatBufferInputParams->recurrentToInputWeights());
3836*89c4ff92SAndroid Build Coastguard Worker inputGateBias = ToConstTensor(flatBufferInputParams->inputGateBias());
3837*89c4ff92SAndroid Build Coastguard Worker
3838*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_InputToInputWeights = &inputToInputWeights;
3839*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_RecurrentToInputWeights = &recurrentToInputWeights;
3840*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_InputGateBias = &inputGateBias;
3841*89c4ff92SAndroid Build Coastguard Worker
3842*89c4ff92SAndroid Build Coastguard Worker if (descriptor.m_PeepholeEnabled)
3843*89c4ff92SAndroid Build Coastguard Worker {
3844*89c4ff92SAndroid Build Coastguard Worker cellToInputWeights = ToConstTensor(flatBufferInputParams->cellToInputWeights());
3845*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_CellToInputWeights = &cellToInputWeights;
3846*89c4ff92SAndroid Build Coastguard Worker }
3847*89c4ff92SAndroid Build Coastguard Worker }
3848*89c4ff92SAndroid Build Coastguard Worker
3849*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor projectionWeights;
3850*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor projectionBias;
3851*89c4ff92SAndroid Build Coastguard Worker if (descriptor.m_ProjectionEnabled)
3852*89c4ff92SAndroid Build Coastguard Worker {
3853*89c4ff92SAndroid Build Coastguard Worker projectionWeights = ToConstTensor(flatBufferInputParams->projectionWeights());
3854*89c4ff92SAndroid Build Coastguard Worker projectionBias = ToConstTensor(flatBufferInputParams->projectionBias());
3855*89c4ff92SAndroid Build Coastguard Worker
3856*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_ProjectionWeights = &projectionWeights;
3857*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_ProjectionBias = &projectionBias;
3858*89c4ff92SAndroid Build Coastguard Worker }
3859*89c4ff92SAndroid Build Coastguard Worker
3860*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor cellToForgetWeights;
3861*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor cellToOutputWeights;
3862*89c4ff92SAndroid Build Coastguard Worker if (descriptor.m_PeepholeEnabled)
3863*89c4ff92SAndroid Build Coastguard Worker {
3864*89c4ff92SAndroid Build Coastguard Worker cellToForgetWeights = ToConstTensor(flatBufferInputParams->cellToForgetWeights());
3865*89c4ff92SAndroid Build Coastguard Worker cellToOutputWeights = ToConstTensor(flatBufferInputParams->cellToOutputWeights());
3866*89c4ff92SAndroid Build Coastguard Worker
3867*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_CellToForgetWeights = &cellToForgetWeights;
3868*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_CellToOutputWeights = &cellToOutputWeights;
3869*89c4ff92SAndroid Build Coastguard Worker }
3870*89c4ff92SAndroid Build Coastguard Worker
3871*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor inputLayerNormWeights;
3872*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor forgetLayerNormWeights;
3873*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor cellLayerNormWeights;
3874*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor outputLayerNormWeights;
3875*89c4ff92SAndroid Build Coastguard Worker if (descriptor.m_LayerNormEnabled)
3876*89c4ff92SAndroid Build Coastguard Worker {
3877*89c4ff92SAndroid Build Coastguard Worker if (!descriptor.m_CifgEnabled)
3878*89c4ff92SAndroid Build Coastguard Worker {
3879*89c4ff92SAndroid Build Coastguard Worker inputLayerNormWeights = ToConstTensor(flatBufferInputParams->inputLayerNormWeights());
3880*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_InputLayerNormWeights = &inputLayerNormWeights;
3881*89c4ff92SAndroid Build Coastguard Worker }
3882*89c4ff92SAndroid Build Coastguard Worker forgetLayerNormWeights = ToConstTensor(flatBufferInputParams->forgetLayerNormWeights());
3883*89c4ff92SAndroid Build Coastguard Worker cellLayerNormWeights = ToConstTensor(flatBufferInputParams->cellLayerNormWeights());
3884*89c4ff92SAndroid Build Coastguard Worker outputLayerNormWeights = ToConstTensor(flatBufferInputParams->outputLayerNormWeights());
3885*89c4ff92SAndroid Build Coastguard Worker
3886*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_ForgetLayerNormWeights = &forgetLayerNormWeights;
3887*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_CellLayerNormWeights = &cellLayerNormWeights;
3888*89c4ff92SAndroid Build Coastguard Worker lstmInputParams.m_OutputLayerNormWeights = &outputLayerNormWeights;
3889*89c4ff92SAndroid Build Coastguard Worker }
3890*89c4ff92SAndroid Build Coastguard Worker
3891*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* layer = m_Network->AddUnidirectionalSequenceLstmLayer(descriptor,
3892*89c4ff92SAndroid Build Coastguard Worker lstmInputParams,
3893*89c4ff92SAndroid Build Coastguard Worker layerName.c_str());
3894*89c4ff92SAndroid Build Coastguard Worker
3895*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo0 = ToTensorInfo(outputs[0]);
3896*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo0);
3897*89c4ff92SAndroid Build Coastguard Worker
3898*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo1 = ToTensorInfo(outputs[1]);
3899*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(1).SetTensorInfo(outputTensorInfo1);
3900*89c4ff92SAndroid Build Coastguard Worker
3901*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo2 = ToTensorInfo(outputs[2]);
3902*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(2).SetTensorInfo(outputTensorInfo2);
3903*89c4ff92SAndroid Build Coastguard Worker
3904*89c4ff92SAndroid Build Coastguard Worker RegisterInputSlots(graph, layerIndex, layer);
3905*89c4ff92SAndroid Build Coastguard Worker RegisterOutputSlots(graph, layerIndex, layer);
3906*89c4ff92SAndroid Build Coastguard Worker }
3907*89c4ff92SAndroid Build Coastguard Worker
3908*89c4ff92SAndroid Build Coastguard Worker } // namespace armnnDeserializer
3909