//
// Copyright © 2020-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "Utils.hpp"

#include "ConversionUtils.hpp"

#include <armnn/utility/NumericCast.hpp>
#include <armnnUtils/TensorUtils.hpp>

#include <half/half.hpp>

using Half = half_float::half;

namespace armnn_driver
{

using namespace armnn;
using namespace android::nn;

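// Checks that the operand supplying the weights of a layer has a lifetime the
// driver can consume: a constant copy, a constant reference, or no value at all
// (an omitted optional operand).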
template<typename HalPolicy,
         typename HalOperation = typename HalPolicy::Operation,
         typename HalModel     = typename HalPolicy::Model>
bool IsWeightsValid(const HalOperation& operation,
                    uint32_t inputIndex,
                    const HalModel& model)
{
    using HalOperand         = typename HalPolicy::Operand;
    using HalOperandLifeTime = typename HalPolicy::OperandLifeTime;
    const HalOperand* operand = GetInputOperand<HalPolicy>(operation, inputIndex, model);
    if (!operand)
    {
        Fail("%s: failed to get input operand %i", __func__, inputIndex);
        return false;
    }

    if (operand->lifetime    != HalOperandLifeTime::CONSTANT_COPY
        && operand->lifetime != HalOperandLifeTime::CONSTANT_REFERENCE
        && operand->lifetime != HalOperandLifeTime::NO_VALUE)
    {
        return false;
    }
    return true;
}

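// Returns true if the given DEQUANTIZE operation produces QSymm8 weights that
// feed a FULLY_CONNECTED or LSTM operation. The driver dequantizes such weights
// on the fly, so no dedicated layer is needed for the DEQUANTIZE itself.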
template<typename HalPolicy,
         typename HalOperation = typename HalPolicy::Operation,
         typename HalModel     = typename HalPolicy::Model>
bool IsQSymmDequantizeForWeights(const HalOperation& operation, const HalModel& model)
{
    using HalOperand       = typename HalPolicy::Operand;
    using HalOperationType = typename HalPolicy::OperationType;

    const HalOperand* operand = GetInputOperand<HalPolicy>(operation, 0, model);
    if (!operand)
    {
        return false;
    }

    if (!IsQSymm8(*operand))
    {
        // Only QSymm8 weights are dequantized on the fly by the driver
        return false;
    }

    if (!IsOperandConstant<HalPolicy>(*operand))
    {
        // Non-const input is not accepted for weights
        return false;
    }

    // Iterate through all the operations and find the operation feeding from the Dequantize output
    const size_t outputIndex = operation.outputs[0];
    for (uint32_t operationIdx = 0; operationIdx < getMainModel(model).operations.size(); ++operationIdx)
    {
        const auto& operationIt = getMainModel(model).operations[operationIdx];
        switch (operationIt.type)
        {
            case HalOperationType::FULLY_CONNECTED:
                if (outputIndex == operationIt.inputs[1]) // Weights are bound to slot 1
                {
                    // If the output is going into the FC weights return true
                    return true;
                }
                break;
            case HalOperationType::LSTM:
                for (size_t k = 0; k < operationIt.inputs.size(); ++k)
                {
                    if (outputIndex == operationIt.inputs[k])
                    {
                        // If the output is going into the LSTM weights return true
                        return true;
                    }
                }
                break;
            default:
                break;
        }
    }

    return false;
}

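// Registers the given output slot of a layer as the producer of the operation's
// output operand, overriding the operand's tensor info with the one supplied.
// The caller must ensure layerOutputIndex addresses a valid output slot of the
// layer.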
template<typename HalPolicy,
         typename HalOperation = typename HalPolicy::Operation,
         typename HalModel     = typename HalPolicy::Model>
bool SetupAndTrackLayerOutputSlotAndOverrideTensorInfo(const HalOperation& operation,
                                                       uint32_t operationOutputIndex,
                                                       armnn::IConnectableLayer& layer,
                                                       uint32_t layerOutputIndex,
                                                       const HalModel& model,
                                                       ConversionData& data,
                                                       const armnn::TensorInfo tensor_info)
{
    using HalOperand = typename HalPolicy::Operand;

    const HalOperand* outputOperand = GetOutputOperand<HalPolicy>(operation, operationOutputIndex, model);
    if ((outputOperand == nullptr) || (operationOutputIndex >= layer.GetNumOutputSlots()))
    {
        return false;
    }

    armnn::IOutputSlot& outputSlot = layer.GetOutputSlot(layerOutputIndex);

    const uint32_t operandIndex = operation.outputs[operationOutputIndex];
    data.m_OutputSlotForOperand[operandIndex] = &outputSlot;

    outputSlot.SetTensorInfo(tensor_info);

    return true;
}

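// The Convert* functions below share a common shape: validate the operation's
// inputs and outputs, query backend support via FORWARD_LAYER_SUPPORT_FUNC
// (deferred through validateFunc when the output tensor is dynamic), add the
// corresponding Arm NN layer, connect its inputs and register its output slot.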
template<typename HalPolicy,
         typename HalOperation = typename HalPolicy::Operation,
         typename HalModel     = typename HalPolicy::Model>
bool ConvertCast(const HalOperation& operation,
                 const HalModel& model,
                 ConversionData& data)
{
    using HalOperand = typename HalPolicy::Operand;

    ALOGV("HalPolicy::ConvertCast()");

    LayerInputHandle input = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);

    if (!input.IsValid())
    {
        return Fail("%s: Operation has invalid inputs", __func__);
    }

    const HalOperand* output = GetOutputOperand<HalPolicy>(operation, 0, model);
    if (!output)
    {
        return Fail("%s: Could not read output 0", __func__);
    }

    const TensorInfo& inputInfo  = input.GetTensorInfo();
    const TensorInfo& outputInfo = GetTensorInfoForOperand(*output);

    bool isSupported = false;
    armnn::BackendId setBackend;
    auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported)
    {
        FORWARD_LAYER_SUPPORT_FUNC(__func__,
                                   IsCastSupported,
                                   data.m_Backends,
                                   isSupported,
                                   setBackend,
                                   inputInfo,
                                   outputInfo);
    };

    if (!IsDynamicTensor(outputInfo))
    {
        validateFunc(outputInfo, isSupported);
    }
    else
    {
        isSupported = AreDynamicTensorsSupported();
    }

    if (!isSupported)
    {
        return false;
    }

    IConnectableLayer* layer = data.m_Network->AddCastLayer();
    if (!layer)
    {
        return Fail("%s: Could not add the CastLayer", __func__);
    }
    layer->SetBackendId(setBackend);
    input.Connect(layer->GetInputSlot(0));

    return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, model, data, nullptr, validateFunc);
}

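// CHANNEL_SHUFFLE takes three inputs: the tensor to shuffle, the number of
// groups to split the channels into, and the dimension along which to shuffle;
// a negative axis counts back from the last dimension.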
template<typename HalPolicy,
         typename HalOperation = typename HalPolicy::Operation,
         typename HalModel     = typename HalPolicy::Model>
bool ConvertChannelShuffle(const HalOperation& operation,
                           const HalModel& model,
                           ConversionData& data)
{
    using HalOperand = typename HalPolicy::Operand;
    using HalOperandType = typename HalPolicy::OperandType;

    ALOGV("HalPolicy::ConvertChannelShuffle()");

    LayerInputHandle input = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);
    if (!input.IsValid())
    {
        return Fail("%s: Operation has invalid inputs", __func__);
    }
    auto inputDimensions = static_cast<int32_t>(input.GetTensorInfo().GetNumDimensions());

    ChannelShuffleDescriptor descriptor;

    int32_t groups;
    if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::INT32, groups, model, data))
    {
        return Fail("%s: Operation has invalid or unsupported number of groups operand", __func__);
    }
    descriptor.m_NumGroups = static_cast<uint32_t>(groups);

    int32_t axis;
    if (!GetInputScalar<HalPolicy>(operation, 2, HalOperandType::INT32, axis, model, data))
    {
        return Fail("%s: Operation has invalid or unsupported dimension channel shuffle operand", __func__);
    }
    if (((axis < -inputDimensions) && (axis < 0)) || ((axis >= inputDimensions) && (axis > 0)))
    {
        return Fail("%s: Operation has invalid dimension: %d. It is out of bounds [-%d, %d)", __func__, axis,
                    inputDimensions, inputDimensions);
    }
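    // Normalise a negative axis, e.g. for a 4-D input an axis of -1 resolves to 3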
    int positiveAxis = (axis < 0) ? inputDimensions + axis : axis;
    descriptor.m_Axis = static_cast<uint32_t>(positiveAxis);

    const HalOperand* output = GetOutputOperand<HalPolicy>(operation, 0, model);
    if (!output)
    {
        return Fail("%s: Could not read output 0", __func__);
    }

    const TensorInfo& inputInfo  = input.GetTensorInfo();
    const TensorInfo& outputInfo = GetTensorInfoForOperand(*output);

    bool isSupported = false;
    armnn::BackendId setBackend;
    auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported)
    {
        FORWARD_LAYER_SUPPORT_FUNC(__func__,
                                   IsChannelShuffleSupported,
                                   data.m_Backends,
                                   isSupported,
                                   setBackend,
                                   inputInfo,
                                   outputInfo,
                                   descriptor);
    };

    if (!IsDynamicTensor(outputInfo))
    {
        validateFunc(outputInfo, isSupported);
    }
    else
    {
        isSupported = AreDynamicTensorsSupported();
    }

    if (!isSupported)
    {
        return false;
    }

    IConnectableLayer* layer = data.m_Network->AddChannelShuffleLayer(descriptor);
    if (!layer)
    {
        return Fail("%s: Could not add the ChannelShuffleLayer", __func__);
    }
    layer->SetBackendId(setBackend);
    input.Connect(layer->GetInputSlot(0));

    return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, model, data, nullptr, validateFunc);
}

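// Converts a HAL 1.2 comparison operation (EQUAL, GREATER, LESS, ...) to an
// Arm NN ComparisonLayer, broadcasting the two inputs where their shapes differ.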
template<typename HalPolicy,
         typename HalOperation = typename HalPolicy::Operation,
         typename HalModel     = typename HalPolicy::Model>
bool ConvertComparison_1_2(const HalOperation& operation,
                           const HalModel& model,
                           ConversionData& data,
                           ComparisonOperation comparisonOperation)
{
    using HalOperand = typename HalPolicy::Operand;

    ALOGV("HalPolicy::ConvertComparison()");
    ALOGV("comparisonOperation = %s", GetComparisonOperationAsCString(comparisonOperation));

    LayerInputHandle input0 = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);
    LayerInputHandle input1 = ConvertToLayerInputHandle<HalPolicy>(operation, 1, model, data);

    if (!(input0.IsValid() && input1.IsValid()))
    {
        return Fail("%s: Operation has invalid inputs", __func__);
    }

    const HalOperand* output = GetOutputOperand<HalPolicy>(operation, 0, model);
    if (!output)
    {
        return Fail("%s: Could not read output 0", __func__);
    }

    const TensorInfo& inputInfo0 = input0.GetTensorInfo();
    const TensorInfo& inputInfo1 = input1.GetTensorInfo();
    const TensorInfo& outputInfo = GetTensorInfoForOperand(*output);

    ComparisonDescriptor descriptor(comparisonOperation);

    bool isSupported = false;
    armnn::BackendId setBackend;
    auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported)
    {
        FORWARD_LAYER_SUPPORT_FUNC(__func__,
                                   IsComparisonSupported,
                                   data.m_Backends,
                                   isSupported,
                                   setBackend,
                                   inputInfo0,
                                   inputInfo1,
                                   outputInfo,
                                   descriptor);
    };

    if (!IsDynamicTensor(outputInfo))
    {
        validateFunc(outputInfo, isSupported);
    }
    else
    {
        isSupported = AreDynamicTensorsSupported();
    }

    if (!isSupported)
    {
        return false;
    }

    IConnectableLayer* layer = data.m_Network->AddComparisonLayer(descriptor);
    if (!layer)
    {
        return Fail("%s: Could not add the ComparisonLayer", __func__);
    }
    layer->SetBackendId(setBackend);

    bool isReshapeSupported = BroadcastTensor(input0, input1, layer, data);
    if (!isReshapeSupported)
    {
        return false;
    }

    return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, model, data, nullptr, validateFunc);
}

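// Converts a HAL 1.2 CONV_2D operation. Two signatures exist: implicit padding
// (7 mandatory inputs: input, weights, bias, padding scheme, two strides and an
// activation) and explicit padding (10 mandatory inputs with the four pad
// values given individually); either may be followed by an optional data
// layout flag and optional dilation factors.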
template<typename HalPolicy,
         typename HalOperation = typename HalPolicy::Operation,
         typename HalModel     = typename HalPolicy::Model>
bool ConvertConv2d_1_2(const HalOperation& operation, const HalModel& model, ConversionData& data)
{
    using HalOperand     = typename HalPolicy::Operand;
    using HalOperandType = typename HalPolicy::OperandType;

    ALOGV("HalPolicy::ConvertConv2d_1_2()");

    LayerInputHandle input = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);
    if (!input.IsValid())
    {
        return Fail("%s: Operation has invalid inputs", __func__);
    }

    const HalOperand* output = GetOutputOperand<HalPolicy>(operation, 0, model);
    if (!output)
    {
        return Fail("%s: Could not read output 0", __func__);
    }

    const TensorInfo& inputInfo  = input.GetTensorInfo();
    const TensorInfo& outputInfo = GetTensorInfoForOperand(*output);

    Convolution2dDescriptor desc;
    desc.m_DataLayout = DataLayout::NHWC;

    // Determine whether padding is implicit or explicit
    bool implicitPadding = operation.inputs.size() == 7 ||
                           (operation.inputs.size() >= 8 &&
                            GetInputOperand<HalPolicy>(operation, 7, model)->type == HalOperandType::BOOL);

    if (implicitPadding)
    {
        desc.m_DataLayout = OptionalDataLayout<HalPolicy>(operation, 7, model, data);
    }
    else if (operation.inputs.size() >= 10)
    {
        desc.m_DataLayout = OptionalDataLayout<HalPolicy>(operation, 10, model, data);
    }

    const PermutationVector OHWIToOIHW = {0, 2, 3, 1};

    // ArmNN does not currently support non-fixed weights or bias.
    // The NNAPI filter is always OHWI [depth_out, filter_height, filter_width, depth_in] but ArmNN expects the
    // filter's height and width indices to match the input's height and width indices so we permute it to OIHW if
    // the DataLayout is NCHW
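    // e.g. the vector {0, 2, 3, 1} sends source dimension i to destination
    // dimension mapping[i], so an OHWI filter [O, H, W, I] becomes [O, I, H, W],
    // i.e. OIHW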

    if (!IsWeightsValid<HalPolicy>(operation, 1, model) && desc.m_DataLayout == DataLayout::NCHW)
    {
        return Fail("%s: Operation has unsupported weights HalOperandLifeTime", __func__);
    }

    LayerInputHandle weightsInput = (desc.m_DataLayout == DataLayout::NCHW) ?
                                     ConvertToLayerInputHandle<HalPolicy>(operation, 1, model, data, OHWIToOIHW) :
                                     ConvertToLayerInputHandle<HalPolicy>(operation, 1, model, data);

    if (!weightsInput.IsValid())
    {
        return Fail("%s: Operation has invalid inputs", __func__);
    }

    LayerInputHandle biasInput = ConvertToLayerInputHandle<HalPolicy>(operation, 2, model, data); // 1D
    if (!biasInput.IsValid())
    {
        return Fail("%s: Operation has invalid inputs", __func__);
    }

    biasInput.SanitizeQuantizationScale(weightsInput, input);
    armnn::TensorInfo weightsInfo = weightsInput.GetTensorInfo();
    armnn::TensorInfo biasInfo = biasInput.GetTensorInfo();

    ActivationFn activation;

    if (implicitPadding)
    {
        android::nn::PaddingScheme paddingScheme;
        if (!GetInputPaddingScheme<HalPolicy>(operation, 3, paddingScheme, model, data) ||
            !GetInputScalar<HalPolicy>(operation, 4, HalOperandType::INT32, desc.m_StrideX, model, data) ||
            !GetInputScalar<HalPolicy>(operation, 5, HalOperandType::INT32, desc.m_StrideY, model, data) ||
            !GetInputActivationFunction<HalPolicy>(operation, 6, activation, model, data) ||
            !GetOptionalConvolutionDilationParams<HalPolicy>(operation, 8, desc, model, data))
        {
            return Fail("%s: Operation has invalid inputs (implicit padding)", __func__);
        }

        armnnUtils::DataLayoutIndexed dataLayoutIndexed(desc.m_DataLayout);
        unsigned int widthIndex = dataLayoutIndexed.GetWidthIndex();
        unsigned int heightIndex = dataLayoutIndexed.GetHeightIndex();
        const uint32_t kernelX = weightsInfo.GetShape()[widthIndex];
        const uint32_t kernelY = weightsInfo.GetShape()[heightIndex];
        const uint32_t inputX  = inputInfo.GetShape()[widthIndex];
        const uint32_t inputY  = inputInfo.GetShape()[heightIndex];

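        // Derive the pad values from the padding scheme (SAME or VALID), the
        // input size, the kernel size, the stride and the dilation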
        CalcPadding(inputX, kernelX, desc.m_StrideX, desc.m_DilationX, desc.m_PadLeft, desc.m_PadRight, paddingScheme);
        CalcPadding(inputY, kernelY, desc.m_StrideY, desc.m_DilationY, desc.m_PadTop, desc.m_PadBottom, paddingScheme);
    }
    else if (operation.inputs.size() >= 10)
    {
        // explicit padding
        if (!GetInputScalar<HalPolicy>(operation, 3, HalOperandType::INT32, desc.m_PadLeft, model, data) ||
            !GetInputScalar<HalPolicy>(operation, 4, HalOperandType::INT32, desc.m_PadRight, model, data) ||
            !GetInputScalar<HalPolicy>(operation, 5, HalOperandType::INT32, desc.m_PadTop, model, data) ||
            !GetInputScalar<HalPolicy>(operation, 6, HalOperandType::INT32, desc.m_PadBottom, model, data) ||
            !GetInputScalar<HalPolicy>(operation, 7, HalOperandType::INT32, desc.m_StrideX, model, data) ||
            !GetInputScalar<HalPolicy>(operation, 8, HalOperandType::INT32, desc.m_StrideY, model, data) ||
            !GetInputActivationFunction<HalPolicy>(operation, 9, activation, model, data) ||
            !GetOptionalConvolutionDilationParams<HalPolicy>(operation, 11, desc, model, data))
        {
            return Fail("%s: Operation has invalid inputs (explicit padding)", __func__);
        }
    }
    else
    {
        return Fail("%s: Unsupported number of operation inputs", __func__);
    }

    desc.m_BiasEnabled = true;
    Optional<TensorInfo> biases(biasInfo);

    bool isSupported = false;
    armnn::BackendId setBackend;
    auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported)
    {
        FORWARD_LAYER_SUPPORT_FUNC(__func__,
                                   IsConvolution2dSupported,
                                   data.m_Backends,
                                   isSupported,
                                   setBackend,
                                   inputInfo,
                                   outputInfo,
                                   desc,
                                   weightsInfo,
                                   biases);
    };

    if (!IsDynamicTensor(outputInfo))
    {
        validateFunc(outputInfo, isSupported);
    }
    else
    {
        isSupported = AreDynamicTensorsSupported();
    }

    if (!isSupported)
    {
        return false;
    }

    armnn::IConnectableLayer* startLayer = data.m_Network->AddConvolution2dLayer(desc);
    if (!startLayer)
    {
        return Fail("%s: AddConvolution2dLayer failed", __func__);
    }
    startLayer->SetBackendId(setBackend);

    input.Connect(startLayer->GetInputSlot(0));
    weightsInput.Connect(startLayer->GetInputSlot(1));
    biasInput.Connect(startLayer->GetInputSlot(2));

    return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *startLayer, model,
                                                   data, nullptr, validateFunc, activation);
}

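// Converts a HAL 1.2 DEPTHWISE_CONV_2D operation. As with CONV_2D there are
// implicit-padding (8 mandatory inputs, the extra one being the depthwise
// multiplier) and explicit-padding (11 mandatory inputs) signatures, each
// optionally followed by a data layout flag and dilation factors.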
template<typename HalPolicy,
         typename HalOperation = typename HalPolicy::Operation,
         typename HalModel     = typename HalPolicy::Model>
bool ConvertDepthwiseConv2d_1_2(const HalOperation& operation, const HalModel& model, ConversionData& data)
{
    using HalOperand     = typename HalPolicy::Operand;
    using HalOperandType = typename HalPolicy::OperandType;

    ALOGV("HalPolicy::ConvertDepthwiseConv2d_1_2()");

    LayerInputHandle input = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);

    if (!input.IsValid())
    {
        return Fail("%s: Operation has invalid inputs", __func__);
    }

    const HalOperand* output = GetOutputOperand<HalPolicy>(operation, 0, model);

    if (!output)
    {
        return Fail("%s: Could not read output 0", __func__);
    }

    const TensorInfo& inputInfo  = input.GetTensorInfo();
    const TensorInfo& outputInfo = GetTensorInfoForOperand(*output);

    // ArmNN does not currently support non-fixed weights or bias.
    // Find the shape of the weights tensor. In AndroidNN this will be [ 1, H, W, I * M ]
    const HalOperand* weightsOperand = GetInputOperand<HalPolicy>(operation, 1, model);
    if (!weightsOperand)
    {
        return Fail("%s: Could not read weights", __func__);
    }
    if (weightsOperand->dimensions[0] != 1)
    {
        return Fail("%s: Invalid weights; for depthwise convolution, dimension 0 must be 1 but it is %i",
                    __func__, weightsOperand->dimensions[0]);
    }

    DepthwiseConvolution2dDescriptor desc;
    desc.m_DataLayout = DataLayout::NHWC;

    // Determine whether padding is implicit or explicit
    bool implicitPadding = operation.inputs.size() == 8 ||
                           (operation.inputs.size() >= 9 &&
                            GetInputOperand<HalPolicy>(operation, 8, model)->type == HalOperandType::BOOL);

    // Look ahead to find the optional DataLayout, if present
    const uint32_t dataLayoutFlagIndex = implicitPadding ? 8 : 11;
    desc.m_DataLayout = OptionalDataLayout<HalPolicy>(operation, dataLayoutFlagIndex, model, data);

    armnnUtils::DataLayoutIndexed dataLayoutIndexed(desc.m_DataLayout);
    unsigned int widthIndex = dataLayoutIndexed.GetWidthIndex();
    unsigned int heightIndex = dataLayoutIndexed.GetHeightIndex();

    LayerInputHandle weightsInput = ConvertToLayerInputHandle<HalPolicy>(operation, 1, model, data);
    if (!weightsInput.IsValid())
    {
        return Fail("%s: Operation has invalid inputs", __func__);
    }

    const HalOperand* biasOperand = GetInputOperand<HalPolicy>(operation, 2, model);
    if (!biasOperand)
    {
        return Fail("%s: Could not read bias", __func__);
    }

    LayerInputHandle biasInput = ConvertToLayerInputHandle<HalPolicy>(operation, 2, model, data); // 1D
    if (!biasInput.IsValid())
    {
        return Fail("%s: Operation has invalid inputs", __func__);
    }

    biasInput.SanitizeQuantizationScale(weightsInput, input);
    armnn::TensorInfo weightsInfo = weightsInput.GetTensorInfo();
    armnn::TensorInfo biasInfo = biasInput.GetTensorInfo();

    ActivationFn activation;

    if (implicitPadding)
    {
        android::nn::PaddingScheme paddingScheme;
        if (!GetInputPaddingScheme<HalPolicy>(operation, 3, paddingScheme, model, data) ||
            !GetInputScalar<HalPolicy>(operation, 4, HalOperandType::INT32, desc.m_StrideX, model, data) ||
            !GetInputScalar<HalPolicy>(operation, 5, HalOperandType::INT32, desc.m_StrideY, model, data) ||
            !GetInputActivationFunction<HalPolicy>(operation, 7, activation, model, data) ||
            !GetOptionalConvolutionDilationParams<HalPolicy>(operation, 9, desc, model, data))
        {
            return Fail("%s: Operation has invalid inputs (implicit padding)", __func__);
        }

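        // NNAPI depthwise weights are laid out as [ 1, H, W, I * M ], so the
        // kernel height and width sit at fixed indices 1 and 2, independent of
        // the input's data layout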
        const uint32_t kernelX = weightsInfo.GetShape()[2];
        const uint32_t kernelY = weightsInfo.GetShape()[1];
        const uint32_t inputX  = inputInfo.GetShape()[widthIndex];
        const uint32_t inputY  = inputInfo.GetShape()[heightIndex];

        CalcPadding(inputX, kernelX, desc.m_StrideX, desc.m_DilationX, desc.m_PadLeft, desc.m_PadRight, paddingScheme);
        CalcPadding(inputY, kernelY, desc.m_StrideY, desc.m_DilationY, desc.m_PadTop, desc.m_PadBottom, paddingScheme);
    }
    else if (operation.inputs.size() >= 11)
    {
        // explicit padding
        if (!GetInputScalar<HalPolicy>(operation, 3, HalOperandType::INT32, desc.m_PadLeft, model, data) ||
            !GetInputScalar<HalPolicy>(operation, 4, HalOperandType::INT32, desc.m_PadRight, model, data) ||
            !GetInputScalar<HalPolicy>(operation, 5, HalOperandType::INT32, desc.m_PadTop, model, data) ||
            !GetInputScalar<HalPolicy>(operation, 6, HalOperandType::INT32, desc.m_PadBottom, model, data) ||
            !GetInputScalar<HalPolicy>(operation, 7, HalOperandType::INT32, desc.m_StrideX, model, data) ||
            !GetInputScalar<HalPolicy>(operation, 8, HalOperandType::INT32, desc.m_StrideY, model, data) ||
            !GetInputActivationFunction<HalPolicy>(operation, 10, activation, model, data) ||
            !GetOptionalConvolutionDilationParams<HalPolicy>(operation, 12, desc, model, data))
        {
            return Fail("%s: Operation has invalid inputs (explicit padding)", __func__);
        }
    }
    else
    {
        return Fail("%s: Unsupported number of operation inputs", __func__);
    }

    desc.m_BiasEnabled = true;
    Optional<TensorInfo> biases(biasInfo);

    bool isSupported = false;
    armnn::BackendId setBackend;
    auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported)
    {
        FORWARD_LAYER_SUPPORT_FUNC(__func__,
                                   IsDepthwiseConvolutionSupported,
                                   data.m_Backends,
                                   isSupported,
                                   setBackend,
                                   inputInfo,
                                   outputInfo,
                                   desc,
                                   weightsInfo,
                                   biases);
    };

    if (!IsDynamicTensor(outputInfo))
    {
        validateFunc(outputInfo, isSupported);
    }
    else
    {
        isSupported = AreDynamicTensorsSupported();
    }

    if (!isSupported)
    {
        return false;
    }

    armnn::IConnectableLayer* startLayer = data.m_Network->AddDepthwiseConvolution2dLayer(desc);
    if (!startLayer)
    {
        return Fail("%s: AddDepthwiseConvolution2dLayer failed", __func__);
    }
    startLayer->SetBackendId(setBackend);

    input.Connect(startLayer->GetInputSlot(0));

    // Connect weights and bias inputs
    weightsInput.Connect(startLayer->GetInputSlot(1));
    biasInput.Connect(startLayer->GetInputSlot(2));

    return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *startLayer, model,
                                                   data, nullptr, validateFunc, activation);
}

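// DEQUANTIZE of QSymm8 weights is folded into the consuming layer by the
// driver; anything else is lowered to a standalone Dequantize layer.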
template<typename HalPolicy,
         typename HalOperation = typename HalPolicy::Operation,
         typename HalModel     = typename HalPolicy::Model>
bool ConvertDequantize_1_2(const HalOperation& operation, const HalModel& model, ConversionData& data)
{
    ALOGV("HalPolicy::ConvertDequantize()");

    if (IsQSymmDequantizeForWeights<HalPolicy>(operation, model))
    {
        // NOTE: QSymm8 weights are dequantized internally by the driver,
        // therefore this type of Dequantize is implicitly supported
        return true;
    }

    return ::ConvertDequantize<HalPolicy>(operation, model, data);
}

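// Converts a HAL 1.2 elementwise unary operation (e.g. ABS, EXP, NEG, RSQRT,
// SQRT) to an Arm NN ElementwiseUnaryLayer.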
template<typename HalPolicy,
         typename HalOperation = typename HalPolicy::Operation,
         typename HalModel     = typename HalPolicy::Model>
bool ConvertElementwiseUnary(const HalOperation& operation,
                             const HalModel& model,
                             ConversionData& data,
                             UnaryOperation unaryOperation)
{
    using HalOperand = typename HalPolicy::Operand;

    ALOGV("HalPolicy::ConvertElementwiseUnary()");
    ALOGV("unaryOperation = %s", GetUnaryOperationAsCString(unaryOperation));

    LayerInputHandle input = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);

    if (!input.IsValid())
    {
        return Fail("%s: Operation has invalid input", __func__);
    }

    const HalOperand* output = GetOutputOperand<HalPolicy>(operation, 0, model);
    if (!output)
    {
        return Fail("%s: Could not read output 0", __func__);
    }

    const TensorInfo& inputInfo = input.GetTensorInfo();
    const TensorInfo& outputInfo = GetTensorInfoForOperand(*output);

    ElementwiseUnaryDescriptor descriptor(unaryOperation);

    bool isSupported = false;
    armnn::BackendId setBackend;
    auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported)
    {
        FORWARD_LAYER_SUPPORT_FUNC(__func__,
                                   IsElementwiseUnarySupported,
                                   data.m_Backends,
                                   isSupported,
                                   setBackend,
                                   inputInfo,
                                   outputInfo,
                                   descriptor);
    };

    if (!IsDynamicTensor(outputInfo))
    {
        validateFunc(outputInfo, isSupported);
    }
    else
    {
        isSupported = AreDynamicTensorsSupported();
    }

    if (!isSupported)
    {
        return false;
    }

    IConnectableLayer* layer = data.m_Network->AddElementwiseUnaryLayer(descriptor);
    if (!layer)
    {
        return Fail("%s: Could not add the ElementwiseUnaryLayer", __func__);
    }
    layer->SetBackendId(setBackend);
    input.Connect(layer->GetInputSlot(0));

    return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, model, data, nullptr, validateFunc);
}

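// Converts a HAL 1.2 EXPAND_DIMS operation by inserting a dimension of size 1
// at the given axis and lowering it to an Arm NN Reshape layer.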
template<typename HalPolicy,
         typename HalOperation = typename HalPolicy::Operation,
         typename HalModel     = typename HalPolicy::Model>
bool ConvertExpandDims(const HalOperation& operation, const HalModel& model, ConversionData& data)
{
    using HalOperand     = typename HalPolicy::Operand;
    using HalOperandType = typename HalPolicy::OperandType;

    ALOGV("HalPolicy::ConvertExpandDims()");

    LayerInputHandle input = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);

    if (!input.IsValid())
    {
        return Fail("%s: Operation has invalid input", __func__);
    }

    const HalOperand* output = GetOutputOperand<HalPolicy>(operation, 0, model);
    if (!output)
    {
        return Fail("%s: Operation has invalid output", __func__);
    }

    const TensorInfo& outputInfo = GetTensorInfoForOperand(*output);

    int32_t axis;
    if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::INT32, axis, model, data))
    {
        return Fail("%s: failed to get axis input value", __func__);
    }

    TensorShape targetShape;

    try
    {
        targetShape = armnnUtils::ExpandDims(input.GetTensorInfo().GetShape(), axis);
    }
    catch (const std::exception& e)
    {
        return Fail("%s: %s", __func__, e.what());
    }

    ReshapeDescriptor reshapeDescriptor;
    reshapeDescriptor.m_TargetShape = targetShape;

    bool isSupported = false;
    armnn::BackendId setBackend;
    auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported)
    {
        FORWARD_LAYER_SUPPORT_FUNC(__func__,
                                   IsReshapeSupported,
                                   data.m_Backends,
                                   isSupported,
                                   setBackend,
                                   input.GetTensorInfo(),
                                   outputInfo,
                                   reshapeDescriptor);
    };

    if (!IsDynamicTensor(outputInfo))
    {
        if (targetShape != outputInfo.GetShape())
        {
            return Fail("%s: Shape of the output operand does not match the resolved expanded shape", __func__);
        }
        validateFunc(outputInfo, isSupported);
    }
    else
    {
        isSupported = AreDynamicTensorsSupported();
    }

    if (!isSupported)
    {
        return false;
    }

    IConnectableLayer* layer = data.m_Network->AddReshapeLayer(reshapeDescriptor);
    if (!layer)
    {
        return Fail("%s: Could not add the ReshapeLayer", __func__);
    }
    layer->SetBackendId(setBackend);
    input.Connect(layer->GetInputSlot(0));

    return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, model, data, nullptr, validateFunc);
}

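// Converts a HAL 1.2 GATHER operation: operand 0 is the input tensor, operand 1
// the axis to gather along and operand 2 the indices; the output rank must be
// input rank + indices rank - 1.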
template<typename HalPolicy,
         typename HalOperation = typename HalPolicy::Operation,
         typename HalModel     = typename HalPolicy::Model>
bool ConvertGather(const HalOperation& operation, const HalModel& model, ConversionData& data)
{
    using HalOperand     = typename HalPolicy::Operand;
    using HalOperandType = typename HalPolicy::OperandType;

    ALOGV("HalPolicy::ConvertGather()");

    LayerInputHandle input = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);
    if (!input.IsValid())
    {
        return Fail("%s: Operation has invalid input", __func__);
    }
    auto inputDimensions = input.GetTensorInfo().GetNumDimensions();

    LayerInputHandle indices = ConvertToLayerInputHandle<HalPolicy>(operation, 2, model, data);
    if (!indices.IsValid())
    {
        return Fail("%s: Operation has invalid indices", __func__);
    }
    auto indicesDimensions = indices.GetTensorInfo().GetNumDimensions();

    const HalOperand* output = GetOutputOperand<HalPolicy>(operation, 0, model);
    if (!output)
    {
        return Fail("%s: Operation has invalid output", __func__);
    }
    const TensorInfo& outputInfo = GetTensorInfoForOperand(*output);
    auto outputDimensions = outputInfo.GetNumDimensions();
    if (outputDimensions != inputDimensions + indicesDimensions - 1)
    {
        return Fail("%s: Operation has invalid output dimensions: %d. Output must be an (%d + %d - 1)-D tensor",
                     __func__, outputDimensions, inputDimensions, indicesDimensions);
    }

    int32_t axis;
    if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::INT32, axis, model, data))
    {
        return Fail("%s: Operation has invalid or unsupported axis operand", __func__);
    }
    int32_t inputDimensions_int = static_cast<int32_t>(inputDimensions);
    if ((axis < -inputDimensions_int) || (inputDimensions_int <= axis))
    {
        return Fail("%s: Operation has invalid axis: %d. It is out of bounds [-%d, %d)", __func__, axis,
                    inputDimensions, inputDimensions);
    }

    GatherDescriptor desc;
    desc.m_Axis = axis;

    bool isSupported = false;
    armnn::BackendId setBackend;
    auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported)
    {
        FORWARD_LAYER_SUPPORT_FUNC(__func__,
                                   IsGatherSupported,
                                   data.m_Backends,
                                   isSupported,
                                   setBackend,
                                   input.GetTensorInfo(),
                                   indices.GetTensorInfo(),
                                   outputInfo,
                                   desc);
    };

    if (!IsDynamicTensor(outputInfo))
    {
        validateFunc(outputInfo, isSupported);
    }
    else
    {
        isSupported = AreDynamicTensorsSupported();
    }

    if (!isSupported)
    {
        return false;
    }

    IConnectableLayer* layer = data.m_Network->AddGatherLayer(desc);
    if (!layer)
    {
        return Fail("%s: Could not add the GatherLayer", __func__);
    }
    layer->SetBackendId(setBackend);
    input.Connect(layer->GetInputSlot(0));
    indices.Connect(layer->GetInputSlot(1));

    return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, model, data, nullptr, validateFunc);
}

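// Converts a HAL 1.2 GROUPED_CONV_2D operation. The explicit-padding signature
// has 12 inputs (data layout flag at index 11); the implicit-padding signature
// has 9 (data layout flag at index 8).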
976*3e777be0SXin Li template<typename HalPolicy,
977*3e777be0SXin Li          typename HalOperation = typename HalPolicy::Operation,
978*3e777be0SXin Li          typename HalModel     = typename HalPolicy::Model>
ConvertGroupedConv2d(const HalOperation & operation,const HalModel & model,ConversionData & data)979*3e777be0SXin Li bool ConvertGroupedConv2d(const HalOperation& operation, const HalModel& model, ConversionData& data)
980*3e777be0SXin Li {
981*3e777be0SXin Li     using HalOperand     = typename HalPolicy::Operand;
982*3e777be0SXin Li     using HalOperandType = typename HalPolicy::OperandType;
983*3e777be0SXin Li 
984*3e777be0SXin Li     ALOGV("HalPolicy::ConvertGroupedConv2d()");
985*3e777be0SXin Li 
986*3e777be0SXin Li     //
987*3e777be0SXin Li     // Parse data
988*3e777be0SXin Li     //
989*3e777be0SXin Li     LayerInputHandle input = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);
990*3e777be0SXin Li     if (!input.IsValid())
991*3e777be0SXin Li     {
992*3e777be0SXin Li         return Fail("%s: Operation has invalid inputs", __func__);
993*3e777be0SXin Li     }
994*3e777be0SXin Li     const TensorInfo& inputInfo  = input.GetTensorInfo();
995*3e777be0SXin Li 
996*3e777be0SXin Li     const HalOperand* output = GetOutputOperand<HalPolicy>(operation, 0, model);
997*3e777be0SXin Li     if (!output)
998*3e777be0SXin Li     {
999*3e777be0SXin Li         return Fail("%s: Could not read output 0", __func__);
1000*3e777be0SXin Li     }
1001*3e777be0SXin Li     TensorInfo outputInfo = GetTensorInfoForOperand(*output);
1002*3e777be0SXin Li 
1003*3e777be0SXin Li     // Look ahead to determine data layout
1004*3e777be0SXin Li     DataLayout dataLayout = DataLayout::NHWC;
1005*3e777be0SXin Li     if (operation.inputs.size() == 12)
1006*3e777be0SXin Li     {
1007*3e777be0SXin Li         dataLayout = OptionalDataLayout<HalPolicy>(operation, 11, model, data);
1008*3e777be0SXin Li     }
1009*3e777be0SXin Li     else
1010*3e777be0SXin Li     {
1011*3e777be0SXin Li         dataLayout = OptionalDataLayout<HalPolicy>(operation, 8, model, data);
1012*3e777be0SXin Li     }
1013*3e777be0SXin Li 
1014*3e777be0SXin Li     // NOTE:
1015*3e777be0SXin Li     // NNAPI weights are always OHWI, i.e. [depth_out, filter_height, filter_width, depth_group],
1016*3e777be0SXin Li     // but Arm NN expects the filter's height and width indices to match the input's height and
1017*3e777be0SXin Li     // width indices, so when the DataLayout is NCHW we need to permute the weights to OIHW
1018*3e777be0SXin Li     const PermutationVector ohwiToOihw = { 0u, 2u, 3u, 1u };
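    // Illustrative note (the example shape is an assumption, not from a real model):
    // PermutationVector maps source dimension i to destination dimension mappings[i], so with
    // { 0u, 2u, 3u, 1u } an OHWI weights tensor of shape [16, 3, 3, 8] is remapped
    // (O -> dim 0, H -> dim 2, W -> dim 3, I -> dim 1) and comes out OIHW as [16, 8, 3, 3].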
1019*3e777be0SXin Li     const ConstTensorPin weightsPin = (dataLayout == DataLayout::NCHW) ?
1020*3e777be0SXin Li                                       ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 1,
1021*3e777be0SXin Li                                                                                        model, data, ohwiToOihw) :
1022*3e777be0SXin Li                                       ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 1, model, data);
1023*3e777be0SXin Li     const ConstTensorPin biasesPin  =
1024*3e777be0SXin Li         ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 2, model, data);
1025*3e777be0SXin Li     if (!weightsPin.IsValid() || !biasesPin.IsValid())
1026*3e777be0SXin Li     {
1027*3e777be0SXin Li         return Fail("%s: Operation has invalid inputs", __func__);
1028*3e777be0SXin Li     }
1029*3e777be0SXin Li 
1030*3e777be0SXin Li     ConstTensor weights = weightsPin.GetConstTensor();
1031*3e777be0SXin Li     ConstTensor biases  = biasesPin.GetConstTensor();
1032*3e777be0SXin Li     SanitizeBiasQuantizationScale(biases.GetInfo(), weights.GetInfo(), inputInfo);
1033*3e777be0SXin Li 
1034*3e777be0SXin Li     const TensorShape& inputShape   = inputInfo.GetShape();
1035*3e777be0SXin Li     const TensorShape& outputShape  = outputInfo.GetShape();
1036*3e777be0SXin Li     const TensorShape& weightsShape = weights.GetShape();
1037*3e777be0SXin Li 
1038*3e777be0SXin Li     armnnUtils::DataLayoutIndexed dataLayoutIndexed(dataLayout);
1039*3e777be0SXin Li     const unsigned int channelsIndex = dataLayoutIndexed.GetChannelsIndex();
1040*3e777be0SXin Li     const unsigned int heightIndex   = dataLayoutIndexed.GetHeightIndex();
1041*3e777be0SXin Li     const unsigned int widthIndex    = dataLayoutIndexed.GetWidthIndex();
1042*3e777be0SXin Li 
1043*3e777be0SXin Li     Convolution2dDescriptor desc;
1044*3e777be0SXin Li     desc.m_DataLayout  = dataLayout;
1045*3e777be0SXin Li     desc.m_BiasEnabled = true;
1046*3e777be0SXin Li 
1047*3e777be0SXin Li     unsigned int numGroups;
1048*3e777be0SXin Li     ActivationFn activation;
1049*3e777be0SXin Li 
1050*3e777be0SXin Li     if (operation.inputs.size() == 12)
1051*3e777be0SXin Li     {
1052*3e777be0SXin Li         if (!GetInputScalar<HalPolicy>(operation, 3, HalOperandType::INT32, desc.m_PadLeft, model, data) ||
1053*3e777be0SXin Li             !GetInputScalar<HalPolicy>(operation, 4, HalOperandType::INT32, desc.m_PadRight, model, data) ||
1054*3e777be0SXin Li             !GetInputScalar<HalPolicy>(operation, 5, HalOperandType::INT32, desc.m_PadTop, model, data) ||
1055*3e777be0SXin Li             !GetInputScalar<HalPolicy>(operation, 6, HalOperandType::INT32, desc.m_PadBottom, model, data) ||
1056*3e777be0SXin Li             !GetInputScalar<HalPolicy>(operation, 7, HalOperandType::INT32, desc.m_StrideX, model, data) ||
1057*3e777be0SXin Li             !GetInputScalar<HalPolicy>(operation, 8, HalOperandType::INT32, desc.m_StrideY, model, data) ||
1058*3e777be0SXin Li             !GetInputScalar<HalPolicy>(operation, 9, HalOperandType::INT32, numGroups, model, data) ||
1059*3e777be0SXin Li             !GetInputActivationFunction<HalPolicy>(operation, 10, activation, model, data))
1060*3e777be0SXin Li         {
1061*3e777be0SXin Li             return Fail("%s: Operation has invalid inputs (explicit padding)", __func__);
1062*3e777be0SXin Li         }
1063*3e777be0SXin Li 
1064*3e777be0SXin Li     }
1065*3e777be0SXin Li     else if (operation.inputs.size() == 9)
1066*3e777be0SXin Li     {
1067*3e777be0SXin Li         android::nn::PaddingScheme paddingScheme;
1068*3e777be0SXin Li         if (!GetInputPaddingScheme<HalPolicy>(operation, 3, paddingScheme, model, data) ||
1069*3e777be0SXin Li             !GetInputScalar<HalPolicy>(operation, 4, HalOperandType::INT32, desc.m_StrideX, model, data) ||
1070*3e777be0SXin Li             !GetInputScalar<HalPolicy>(operation, 5, HalOperandType::INT32, desc.m_StrideY, model, data) ||
1071*3e777be0SXin Li             !GetInputScalar<HalPolicy>(operation, 6, HalOperandType::INT32, numGroups, model, data) ||
1072*3e777be0SXin Li             !GetInputActivationFunction<HalPolicy>(operation, 7, activation, model, data))
1073*3e777be0SXin Li         {
1074*3e777be0SXin Li             return Fail("%s: Operation has invalid inputs (implicit padding)", __func__);
1075*3e777be0SXin Li         }
1076*3e777be0SXin Li 
1077*3e777be0SXin Li         const uint32_t inputX = inputInfo.GetShape()[widthIndex];
1078*3e777be0SXin Li         const uint32_t inputY = inputInfo.GetShape()[heightIndex];
1079*3e777be0SXin Li 
1080*3e777be0SXin Li         const uint32_t kernelX = weightsShape[widthIndex];
1081*3e777be0SXin Li         const uint32_t kernelY = weightsShape[heightIndex];
1082*3e777be0SXin Li 
1083*3e777be0SXin Li         CalcPadding(inputX, kernelX, desc.m_StrideX, desc.m_PadLeft, desc.m_PadRight, paddingScheme);
1084*3e777be0SXin Li         CalcPadding(inputY, kernelY, desc.m_StrideY, desc.m_PadTop, desc.m_PadBottom, paddingScheme);
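        // Worked example for the SAME scheme (values are illustrative): with inputX = 224,
        // kernelX = 3 and strideX = 2, the output width is ceil(224 / 2) = 112, the total
        // padding needed is (112 - 1) * 2 + 3 - 224 = 1, giving padLeft = 0 and padRight = 1.
        // The VALID scheme simply sets all four pads to zero.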
1085*3e777be0SXin Li     }
1086*3e777be0SXin Li     else
1087*3e777be0SXin Li     {
1088*3e777be0SXin Li         return Fail("%s: Unsupported number of operation inputs", __func__);
1089*3e777be0SXin Li     }
1090*3e777be0SXin Li 
1091*3e777be0SXin Li     // Equivalent to outputShape[channelsIndex], but we can't know the outputShape in the case of dynamic tensors
1092*3e777be0SXin Li     const unsigned int outputChannels = weightsShape[0];
1093*3e777be0SXin Li 
1094*3e777be0SXin Li     //
1095*3e777be0SXin Li     // Validate all relevant inputs (numGroups must be checked before it is used as a divisor)
1096*3e777be0SXin Li     //
1097*3e777be0SXin Li     if (numGroups == 0)
1098*3e777be0SXin Li     {
1099*3e777be0SXin Li         return Fail("%s: Number of groups must be greater than 0. Got: %u", __func__, numGroups);
1100*3e777be0SXin Li     }
1101*3e777be0SXin Li 
1102*3e777be0SXin Li     if (outputChannels % numGroups != 0u)
1103*3e777be0SXin Li     {
1104*3e777be0SXin Li         return Fail("%s: Output channels must be divisible by the number of groups", __func__);
1105*3e777be0SXin Li     }
1106*3e777be0SXin Li 
1107*3e777be0SXin Li     const unsigned int channelsPerGroup  = weightsShape[channelsIndex];
1108*3e777be0SXin Li     const unsigned int channelMultiplier = outputChannels / numGroups;
1109*3e777be0SXin Li 
1110*3e777be0SXin Li     //
1111*3e777be0SXin Li     // Set up Splitter layer
1112*3e777be0SXin Li     //
1113*3e777be0SXin Li     unsigned int splitterDimSizes[4] = { inputShape[0], inputShape[1], inputShape[2], inputShape[3] };
1114*3e777be0SXin Li     splitterDimSizes[channelsIndex] /= numGroups; // split in depth
1115*3e777be0SXin Li 
1116*3e777be0SXin Li     TensorInfo splitterOutputInfo(4,
1117*3e777be0SXin Li                                   splitterDimSizes,
1118*3e777be0SXin Li                                   inputInfo.GetDataType(),
1119*3e777be0SXin Li                                   inputInfo.GetQuantizationScale(),
1120*3e777be0SXin Li                                   inputInfo.GetQuantizationOffset());
1121*3e777be0SXin Li 
1122*3e777be0SXin Li     std::vector<std::reference_wrapper<TensorInfo>> splitterOutputInfos(numGroups, std::ref(splitterOutputInfo));
1123*3e777be0SXin Li 
1124*3e777be0SXin Li     ViewsDescriptor splitterDesc(numGroups);
1125*3e777be0SXin Li     for (unsigned int group = 0u; group < numGroups; ++group)
1126*3e777be0SXin Li     {
1127*3e777be0SXin Li         splitterDesc.SetViewOriginCoord(group, channelsIndex, splitterDimSizes[channelsIndex] * group);
1128*3e777be0SXin Li         for (unsigned int dimIdx = 0u; dimIdx < 4u; dimIdx++)
1129*3e777be0SXin Li         {
1130*3e777be0SXin Li             splitterDesc.SetViewSize(group, dimIdx, splitterDimSizes[dimIdx]);
1131*3e777be0SXin Li         }
1132*3e777be0SXin Li     }
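    // Worked example (illustrative): for an NHWC input of [1, 14, 14, 32] and numGroups = 4,
    // each view has size [1, 14, 14, 8] with channel origins 0, 8, 16 and 24.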
1133*3e777be0SXin Li 
1134*3e777be0SXin Li     bool isSupported = false;
1135*3e777be0SXin Li     armnn::BackendId setBackendSplit;
1136*3e777be0SXin Li     FORWARD_LAYER_SUPPORT_FUNC(__func__,
1137*3e777be0SXin Li                                IsSplitterSupported,
1138*3e777be0SXin Li                                data.m_Backends,
1139*3e777be0SXin Li                                isSupported,
1140*3e777be0SXin Li                                setBackendSplit,
1141*3e777be0SXin Li                                inputInfo,
1142*3e777be0SXin Li                                splitterOutputInfos,
1143*3e777be0SXin Li                                splitterDesc);
1144*3e777be0SXin Li     if (!isSupported)
1145*3e777be0SXin Li     {
1146*3e777be0SXin Li         return false;
1147*3e777be0SXin Li     }
1148*3e777be0SXin Li 
1149*3e777be0SXin Li     IConnectableLayer* splitterLayer = data.m_Network->AddSplitterLayer(splitterDesc);
1150*3e777be0SXin Li     if (!splitterLayer)
1151*3e777be0SXin Li     {
1152*3e777be0SXin Li         return Fail("%s: Failed to add SplitterLayer", __func__);
1153*3e777be0SXin Li     }
1154*3e777be0SXin Li     splitterLayer->SetBackendId(setBackendSplit);
1155*3e777be0SXin Li 
1156*3e777be0SXin Li     input.Connect(splitterLayer->GetInputSlot(0));
1157*3e777be0SXin Li     for (unsigned int group = 0u; group < splitterLayer->GetNumOutputSlots(); ++group)
1158*3e777be0SXin Li     {
1159*3e777be0SXin Li         splitterLayer->GetOutputSlot(group).SetTensorInfo(splitterOutputInfo);
1160*3e777be0SXin Li     }
1161*3e777be0SXin Li 
1162*3e777be0SXin Li     //
1163*3e777be0SXin Li     // Set up Convolution2d layers for each group
1164*3e777be0SXin Li     //
1165*3e777be0SXin Li 
1166*3e777be0SXin Li     // Set up group tensor shapes
1167*3e777be0SXin Li     TensorShape groupInputShape(inputShape);
1168*3e777be0SXin Li     groupInputShape[channelsIndex] = channelsPerGroup;
1169*3e777be0SXin Li 
1170*3e777be0SXin Li     TensorShape groupWeightsShape(weightsShape);
1171*3e777be0SXin Li     groupWeightsShape[0] /= channelMultiplier * numGroups;
1172*3e777be0SXin Li 
1173*3e777be0SXin Li     TensorShape groupBiasesShape({ 1 });
1174*3e777be0SXin Li 
1175*3e777be0SXin Li     // Set up group tensor infos
1176*3e777be0SXin Li     TensorInfo groupInputInfo(inputInfo);
1177*3e777be0SXin Li     groupInputInfo.SetShape(groupInputShape);
1178*3e777be0SXin Li 
1179*3e777be0SXin Li     const TensorInfo& weightsInfo = weights.GetInfo();
1180*3e777be0SXin Li     TensorInfo groupWeightsInfo(weightsInfo);
1181*3e777be0SXin Li     groupWeightsInfo.SetShape(groupWeightsShape);
1182*3e777be0SXin Li 
1183*3e777be0SXin Li     const TensorInfo& biasesInfo = biases.GetInfo();
1184*3e777be0SXin Li     TensorInfo groupBiasesInfo(biasesInfo);
1185*3e777be0SXin Li     groupBiasesInfo.SetShape(groupBiasesShape);
1186*3e777be0SXin Li 
1187*3e777be0SXin Li     TensorInfo groupOutputInfo(outputInfo);
1188*3e777be0SXin Li 
1189*3e777be0SXin Li     TensorShape groupOutputShape(outputShape);
1190*3e777be0SXin Li     const bool isDynamic = IsDynamicTensor(outputInfo);
1191*3e777be0SXin Li     if (!isDynamic)
1192*3e777be0SXin Li     {
1193*3e777be0SXin Li         groupOutputShape[channelsIndex] = 1;
1194*3e777be0SXin Li     }
1195*3e777be0SXin Li     groupOutputInfo.SetShape(groupOutputShape);
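    // Worked example (illustrative, continuing the shapes above): with OHWI weights of
    // [32, 3, 3, 8] and numGroups = 4, outputChannels = 32 and channelMultiplier = 8, so
    // groupInputShape = [1, 14, 14, 8], groupWeightsShape = [1, 3, 3, 8] and, for a static
    // output, groupOutputShape has a single channel; 32 convolutions are created in total.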
1196*3e777be0SXin Li 
1197*3e777be0SXin Li     const unsigned int weightsDataTypeSize = GetDataTypeSize(groupWeightsInfo.GetDataType());
1198*3e777be0SXin Li     const unsigned int biasesDataTypeSize  = GetDataTypeSize(groupBiasesInfo.GetDataType());
1199*3e777be0SXin Li 
1200*3e777be0SXin Li     std::vector<IConnectableLayer*> convLayers(numGroups * channelMultiplier, nullptr);
1201*3e777be0SXin Li     for (unsigned int group = 0u; group < numGroups; ++group)
1202*3e777be0SXin Li     {
1203*3e777be0SXin Li         for (unsigned int m = 0u; m < channelMultiplier; ++m)
1204*3e777be0SXin Li         {
1205*3e777be0SXin Li             auto index = group * channelMultiplier + m;
1206*3e777be0SXin Li 
1207*3e777be0SXin Li             const unsigned int weightsDataOffset = groupWeightsShape.GetNumElements() * index * weightsDataTypeSize;
1208*3e777be0SXin Li             const unsigned int biasesDataOffset = groupBiasesShape.GetNumElements() * index * biasesDataTypeSize;
1209*3e777be0SXin Li 
1210*3e777be0SXin Li             if (weightsInfo.HasPerAxisQuantization())
1211*3e777be0SXin Li             {
1212*3e777be0SXin Li                 // Extract per-axis quantization scales for group weights
1213*3e777be0SXin Li                 const std::vector<float>& weightsQuantScales = weightsInfo.GetQuantizationScales();
1214*3e777be0SXin Li                 groupWeightsInfo.SetQuantizationScales(
1215*3e777be0SXin Li                     std::vector<float>(weightsQuantScales.begin() + index,
1216*3e777be0SXin Li                                        weightsQuantScales.begin() + index + groupWeightsShape[0]));
1217*3e777be0SXin Li 
1218*3e777be0SXin Li                 // Extract per-axis quantization scales for group biases
1219*3e777be0SXin Li                 const std::vector<float>& biasesQuantScales  = biasesInfo.GetQuantizationScales();
1220*3e777be0SXin Li                 groupBiasesInfo.SetQuantizationScales(
1221*3e777be0SXin Li                     std::vector<float>(biasesQuantScales.begin() + index,
1222*3e777be0SXin Li                                        biasesQuantScales.begin() + index + groupWeightsShape[0]));
1223*3e777be0SXin Li             }
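            // Note: channelMultiplier * numGroups == outputChannels, so groupWeightsShape[0]
            // is 1 and each slice above selects the single scale for output channel 'index'.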
1224*3e777be0SXin Li 
1225*3e777be0SXin Li             // Extract weights and biases data for current group convolution
1226*3e777be0SXin Li             ConstTensor groupWeights(groupWeightsInfo,
1227*3e777be0SXin Li                                      static_cast<const void *>(reinterpret_cast<const char *>(weights.GetMemoryArea()) +
1228*3e777be0SXin Li                                                                weightsDataOffset));
1229*3e777be0SXin Li             ConstTensor groupBiases(groupBiasesInfo,
1230*3e777be0SXin Li                                     static_cast<const void *>(reinterpret_cast<const char *>(biases.GetMemoryArea()) +
1231*3e777be0SXin Li                                                               biasesDataOffset));
1232*3e777be0SXin Li 
1233*3e777be0SXin Li             isSupported = false;
1234*3e777be0SXin Li             armnn::BackendId setBackendConv;
1235*3e777be0SXin Li             auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported)
1236*3e777be0SXin Li             {
1237*3e777be0SXin Li                 FORWARD_LAYER_SUPPORT_FUNC(__func__,
1238*3e777be0SXin Li                                            IsConvolution2dSupported,
1239*3e777be0SXin Li                                            data.m_Backends,
1240*3e777be0SXin Li                                            isSupported,
1241*3e777be0SXin Li                                            setBackendConv,
1242*3e777be0SXin Li                                            groupInputInfo,
1243*3e777be0SXin Li                                            outputInfo,
1244*3e777be0SXin Li                                            desc,
1245*3e777be0SXin Li                                            groupWeightsInfo,
1246*3e777be0SXin Li                                            Optional<TensorInfo>(groupBiasesInfo));
1247*3e777be0SXin Li             };
1248*3e777be0SXin Li 
1249*3e777be0SXin Li             if(!isDynamic)
1250*3e777be0SXin Li             {
1251*3e777be0SXin Li                 validateFunc(groupOutputInfo, isSupported);
1252*3e777be0SXin Li             }
1253*3e777be0SXin Li             else
1254*3e777be0SXin Li             {
1255*3e777be0SXin Li                 isSupported = AreDynamicTensorsSupported();
1256*3e777be0SXin Li             }
1257*3e777be0SXin Li 
1258*3e777be0SXin Li             if (!isSupported)
1259*3e777be0SXin Li             {
1260*3e777be0SXin Li                 return false;
1261*3e777be0SXin Li             }
1262*3e777be0SXin Li 
1263*3e777be0SXin Li             IConnectableLayer* weightsLayer = data.m_Network->AddConstantLayer(groupWeights);
1264*3e777be0SXin Li             IConnectableLayer* biasLayer = data.m_Network->AddConstantLayer(groupBiases);
1265*3e777be0SXin Li             IConnectableLayer* convLayer = data.m_Network->AddConvolution2dLayer(desc);
1266*3e777be0SXin Li             if (!convLayer)
1267*3e777be0SXin Li             {
1268*3e777be0SXin Li                 return Fail("%s: AddConvolution2dLayer failed", __func__);
1269*3e777be0SXin Li             }
1270*3e777be0SXin Li 
1271*3e777be0SXin Li             convLayer->SetBackendId(setBackendConv);
1272*3e777be0SXin Li 
1273*3e777be0SXin Li             splitterLayer->GetOutputSlot(group).Connect(convLayer->GetInputSlot(0));
1274*3e777be0SXin Li             weightsLayer->GetOutputSlot(0).Connect(convLayer->GetInputSlot(1));
1275*3e777be0SXin Li             biasLayer->GetOutputSlot(0).Connect(convLayer->GetInputSlot(2));
1276*3e777be0SXin Li 
1277*3e777be0SXin Li             weightsLayer->GetOutputSlot(0).SetTensorInfo(groupWeightsInfo);
1278*3e777be0SXin Li             biasLayer->GetOutputSlot(0).SetTensorInfo(groupBiasesInfo);
1279*3e777be0SXin Li             convLayer->GetOutputSlot(0).SetTensorInfo(groupOutputInfo);
1280*3e777be0SXin Li 
1281*3e777be0SXin Li             if(isDynamic)
1282*3e777be0SXin Li             {
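                // The return value is deliberately discarded here: for a dynamic output
                // tensor this call appears to force the slot's tensor info to be inferred
                // so that the GetTensorInfo() calls below return the inferred shape (an
                // assumption based on how the slot is read back immediately afterwards).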
1283*3e777be0SXin Li                 convLayer->GetOutputSlot(0).IsTensorInfoSet();
1284*3e777be0SXin Li 
1285*3e777be0SXin Li                 validateFunc(convLayer->GetOutputSlot(0).GetTensorInfo(), isSupported);
1286*3e777be0SXin Li 
1287*3e777be0SXin Li                 outputInfo = convLayer->GetOutputSlot(0).GetTensorInfo();
1288*3e777be0SXin Li 
1289*3e777be0SXin Li                 if (!isSupported)
1290*3e777be0SXin Li                 {
1291*3e777be0SXin Li                     return false;
1292*3e777be0SXin Li                 }
1293*3e777be0SXin Li             }
1294*3e777be0SXin Li 
1295*3e777be0SXin Li             convLayers[index] = convLayer;
1296*3e777be0SXin Li         }
1297*3e777be0SXin Li     }
1298*3e777be0SXin Li 
1299*3e777be0SXin Li     //
1300*3e777be0SXin Li     // Set up Concat layer
1301*3e777be0SXin Li     //
1302*3e777be0SXin Li     // The number of views equals outputShape[channelsIndex], but we can't know the outputShape
1303*3e777be0SXin Li     // in the case of dynamic tensors, so weightsShape[0] (the total output channels) is used
1304*3e777be0SXin Li     ConcatDescriptor concatDescriptor(weightsShape[0]);
1305*3e777be0SXin Li     for (unsigned int group = 0u; group < numGroups; ++group)
1306*3e777be0SXin Li     {
1307*3e777be0SXin Li         for (unsigned int m = 0u; m < channelMultiplier; ++m)
1308*3e777be0SXin Li         {
1309*3e777be0SXin Li             auto index = group * channelMultiplier + m;
1310*3e777be0SXin Li             concatDescriptor.SetViewOriginCoord(index, channelsIndex, index);
1311*3e777be0SXin Li             concatDescriptor.SetConcatAxis(channelsIndex);
1312*3e777be0SXin Li         }
1313*3e777be0SXin Li     }
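    // Continuing the example: each of the 32 single-channel views is placed at channel
    // offset 'index', concatenating along the channels dimension back to 32 channels.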
1314*3e777be0SXin Li 
1315*3e777be0SXin Li     isSupported = false;
1316*3e777be0SXin Li     armnn::BackendId setBackendConcat;
1317*3e777be0SXin Li     FORWARD_LAYER_SUPPORT_FUNC(__func__,
1318*3e777be0SXin Li                                IsConcatSupported,
1319*3e777be0SXin Li                                data.m_Backends,
1320*3e777be0SXin Li                                isSupported,
1321*3e777be0SXin Li                                setBackendConcat,
1322*3e777be0SXin Li                                std::vector<const TensorInfo*>(numGroups * channelMultiplier, &groupOutputInfo),
1323*3e777be0SXin Li                                outputInfo,
1324*3e777be0SXin Li                                concatDescriptor);
1325*3e777be0SXin Li 
1326*3e777be0SXin Li     if (!isSupported)
1327*3e777be0SXin Li     {
1328*3e777be0SXin Li         return false;
1329*3e777be0SXin Li     }
1330*3e777be0SXin Li 
1331*3e777be0SXin Li     IConnectableLayer* concatLayer = data.m_Network->AddConcatLayer(concatDescriptor);
1332*3e777be0SXin Li     if (!concatLayer)
1333*3e777be0SXin Li     {
1334*3e777be0SXin Li         return Fail("%s: AddConcatLayer failed", __func__);
1335*3e777be0SXin Li     }
1336*3e777be0SXin Li     concatLayer->SetBackendId(setBackendConcat);
1337*3e777be0SXin Li 
1338*3e777be0SXin Li     for (unsigned int group = 0u; group < numGroups; ++group)
1339*3e777be0SXin Li     {
1340*3e777be0SXin Li         for (unsigned int m = 0u; m < channelMultiplier; ++m)
1341*3e777be0SXin Li         {
1342*3e777be0SXin Li             auto index = group * channelMultiplier + m;
1343*3e777be0SXin Li             convLayers[index]->GetOutputSlot(0).Connect(concatLayer->GetInputSlot(index));
1344*3e777be0SXin Li         }
1345*3e777be0SXin Li     }
1346*3e777be0SXin Li     concatLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
1347*3e777be0SXin Li 
1348*3e777be0SXin Li     return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *concatLayer, model,
1349*3e777be0SXin Li                                                    data, nullptr, nullptr, activation);
1350*3e777be0SXin Li }
1351*3e777be0SXin Li 
1352*3e777be0SXin Li template<typename HalPolicy,
1353*3e777be0SXin Li          typename HalOperation = typename HalPolicy::Operation,
1354*3e777be0SXin Li          typename HalModel     = typename HalPolicy::Model>
1355*3e777be0SXin Li bool ConvertInstanceNormalization(const HalOperation& operation, const HalModel& model, ConversionData& data)
1356*3e777be0SXin Li {
1357*3e777be0SXin Li     using HalOperand     = typename HalPolicy::Operand;
1358*3e777be0SXin Li     using HalOperandType = typename HalPolicy::OperandType;
1359*3e777be0SXin Li 
1360*3e777be0SXin Li     ALOGV("HalPolicy::ConvertInstanceNormalization()");
1361*3e777be0SXin Li 
1362*3e777be0SXin Li     LayerInputHandle input = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);
1363*3e777be0SXin Li     if (!input.IsValid())
1364*3e777be0SXin Li     {
1365*3e777be0SXin Li         return Fail("%s: Operation has an invalid input 0", __func__);
1366*3e777be0SXin Li     }
1367*3e777be0SXin Li 
1368*3e777be0SXin Li     const HalOperand* output = GetOutputOperand<HalPolicy>(operation, 0, model);
1369*3e777be0SXin Li     if (!output)
1370*3e777be0SXin Li     {
1371*3e777be0SXin Li         return Fail("%s: Operation has an invalid output", __func__);
1372*3e777be0SXin Li     }
1373*3e777be0SXin Li 
1374*3e777be0SXin Li     const TensorInfo& outputInfo = GetTensorInfoForOperand(*output);
1375*3e777be0SXin Li 
1376*3e777be0SXin Li     // Determine data type of input tensor
1377*3e777be0SXin Li     HalOperandType inputType;
1378*3e777be0SXin Li     if (!GetOperandType<HalPolicy>(operation, 0, model, inputType))
1379*3e777be0SXin Li     {
1380*3e777be0SXin Li         return Fail("%s: Operation has invalid inputs", __func__);
1381*3e777be0SXin Li     }
1382*3e777be0SXin Li 
1383*3e777be0SXin Li     InstanceNormalizationDescriptor desc;
1384*3e777be0SXin Li 
1385*3e777be0SXin Li     // Read gamma, beta & epsilon
1386*3e777be0SXin Li     if (inputType == HalOperandType::TENSOR_FLOAT16)
1387*3e777be0SXin Li     {
1388*3e777be0SXin Li         Half fp16Gamma;
1389*3e777be0SXin Li         Half fp16Beta;
1390*3e777be0SXin Li         Half fp16Epsilon;
1391*3e777be0SXin Li 
1392*3e777be0SXin Li         if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::FLOAT16, fp16Gamma, model, data) ||
1393*3e777be0SXin Li             !GetInputScalar<HalPolicy>(operation, 2, HalOperandType::FLOAT16, fp16Beta, model, data) ||
1394*3e777be0SXin Li             !GetInputScalar<HalPolicy>(operation, 3, HalOperandType::FLOAT16, fp16Epsilon, model, data))
1395*3e777be0SXin Li         {
1396*3e777be0SXin Li             return Fail("%s: Operation has invalid inputs (FLOAT16)", __func__);
1397*3e777be0SXin Li         }
1398*3e777be0SXin Li 
1399*3e777be0SXin Li         desc.m_Gamma = static_cast<float>(fp16Gamma);
1400*3e777be0SXin Li         desc.m_Beta  = static_cast<float>(fp16Beta);
1401*3e777be0SXin Li         desc.m_Eps   = static_cast<float>(fp16Epsilon);
1402*3e777be0SXin Li     }
1403*3e777be0SXin Li     else if (inputType == HalOperandType::TENSOR_FLOAT32)
1404*3e777be0SXin Li     {
1405*3e777be0SXin Li         if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::FLOAT32, desc.m_Gamma, model, data) ||
1406*3e777be0SXin Li             !GetInputScalar<HalPolicy>(operation, 2, HalOperandType::FLOAT32, desc.m_Beta, model, data) ||
1407*3e777be0SXin Li             !GetInputScalar<HalPolicy>(operation, 3, HalOperandType::FLOAT32, desc.m_Eps, model, data))
1408*3e777be0SXin Li         {
1409*3e777be0SXin Li             return Fail("%s: Operation has invalid inputs (FLOAT32)", __func__);
1410*3e777be0SXin Li         }
1411*3e777be0SXin Li     }
1412*3e777be0SXin Li     else
1413*3e777be0SXin Li     {
1414*3e777be0SXin Li         return Fail("%s: Unsupported input tensor type: %d", __func__, inputType);
1415*3e777be0SXin Li     }
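    // For reference, these parameters feed the standard instance normalization formula:
    //   out = gamma * (x - mean) / sqrt(variance + epsilon) + beta
    // with mean and variance computed per batch entry and per channel over the spatial
    // dimensions.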
1416*3e777be0SXin Li 
1417*3e777be0SXin Li     desc.m_DataLayout = OptionalDataLayout<HalPolicy>(operation, 4, model, data);
1418*3e777be0SXin Li 
1419*3e777be0SXin Li     bool isSupported = false;
1420*3e777be0SXin Li     armnn::BackendId setBackend;
1421*3e777be0SXin Li     auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported)
1422*3e777be0SXin Li     {
1423*3e777be0SXin Li         FORWARD_LAYER_SUPPORT_FUNC(__func__,
1424*3e777be0SXin Li                                    IsInstanceNormalizationSupported,
1425*3e777be0SXin Li                                    data.m_Backends,
1426*3e777be0SXin Li                                    isSupported,
1427*3e777be0SXin Li                                    setBackend,
1428*3e777be0SXin Li                                    input.GetTensorInfo(),
1429*3e777be0SXin Li                                    outputInfo,
1430*3e777be0SXin Li                                    desc);
1431*3e777be0SXin Li     };
1432*3e777be0SXin Li 
1433*3e777be0SXin Li     if(IsDynamicTensor(outputInfo))
1434*3e777be0SXin Li     {
1435*3e777be0SXin Li         isSupported = AreDynamicTensorsSupported();
1436*3e777be0SXin Li     }
1437*3e777be0SXin Li     else
1438*3e777be0SXin Li     {
1439*3e777be0SXin Li         validateFunc(outputInfo, isSupported);
1440*3e777be0SXin Li     }
1441*3e777be0SXin Li 
1442*3e777be0SXin Li     if (!isSupported)
1443*3e777be0SXin Li     {
1444*3e777be0SXin Li         return false;
1445*3e777be0SXin Li     }
1446*3e777be0SXin Li 
1447*3e777be0SXin Li     IConnectableLayer* layer = data.m_Network->AddInstanceNormalizationLayer(desc);
1448*3e777be0SXin Li     layer->SetBackendId(setBackend);
1449*3e777be0SXin Li     input.Connect(layer->GetInputSlot(0));
1450*3e777be0SXin Li 
1451*3e777be0SXin Li     return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, model, data, nullptr, validateFunc);
1452*3e777be0SXin Li }
1453*3e777be0SXin Li 
1454*3e777be0SXin Li template<typename HalPolicy,
1455*3e777be0SXin Li          typename HalOperation = typename HalPolicy::Operation,
1456*3e777be0SXin Li          typename HalModel     = typename HalPolicy::Model>
1457*3e777be0SXin Li bool ConvertLogSoftmax(const HalOperation& operation, const HalModel& model, ConversionData& data)
1458*3e777be0SXin Li {
1459*3e777be0SXin Li     using HalOperand     = typename HalPolicy::Operand;
1460*3e777be0SXin Li     using HalOperandType = typename HalPolicy::OperandType;
1461*3e777be0SXin Li 
1462*3e777be0SXin Li     ALOGV("HalPolicy::ConvertLogSoftmax()");
1463*3e777be0SXin Li 
1464*3e777be0SXin Li     LayerInputHandle input = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);
1465*3e777be0SXin Li     if (!input.IsValid())
1466*3e777be0SXin Li     {
1467*3e777be0SXin Li         return Fail("%s: Failed to read input 0", __func__);
1468*3e777be0SXin Li     }
1469*3e777be0SXin Li 
1470*3e777be0SXin Li     const HalOperand* output = GetOutputOperand<HalPolicy>(operation, 0, model);
1471*3e777be0SXin Li     if (!output)
1472*3e777be0SXin Li     {
1473*3e777be0SXin Li         return Fail("%s: Failed to read output", __func__);
1474*3e777be0SXin Li     }
1475*3e777be0SXin Li 
1476*3e777be0SXin Li     const TensorInfo& outputInfo = GetTensorInfoForOperand(*output);
1477*3e777be0SXin Li 
1478*3e777be0SXin Li     // Determine data type of input tensor
1479*3e777be0SXin Li     HalOperandType inputType;
1480*3e777be0SXin Li     if (!GetOperandType<HalPolicy>(operation, 0, model, inputType))
1481*3e777be0SXin Li     {
1482*3e777be0SXin Li         return Fail("%s: Operation has invalid inputs", __func__);
1483*3e777be0SXin Li     }
1484*3e777be0SXin Li 
1485*3e777be0SXin Li     LogSoftmaxDescriptor descriptor;
1486*3e777be0SXin Li 
1487*3e777be0SXin Li     // Read beta
1488*3e777be0SXin Li     if (inputType == HalOperandType::TENSOR_FLOAT16)
1489*3e777be0SXin Li     {
1490*3e777be0SXin Li         Half fp16Beta;
1491*3e777be0SXin Li         if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::FLOAT16, fp16Beta, model, data))
1492*3e777be0SXin Li         {
1493*3e777be0SXin Li             return Fail("%s: Failed to read input 1 (FLOAT16)", __func__);
1494*3e777be0SXin Li         }
1495*3e777be0SXin Li 
1496*3e777be0SXin Li         descriptor.m_Beta  = static_cast<float>(fp16Beta);
1497*3e777be0SXin Li     }
1498*3e777be0SXin Li     else if (inputType == HalOperandType::TENSOR_FLOAT32)
1499*3e777be0SXin Li     {
1500*3e777be0SXin Li         if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::FLOAT32, descriptor.m_Beta, model, data))
1501*3e777be0SXin Li         {
1502*3e777be0SXin Li             return Fail("%s: Failed to read input 1 (FLOAT32)", __func__);
1503*3e777be0SXin Li         }
1504*3e777be0SXin Li     }
1505*3e777be0SXin Li     else
1506*3e777be0SXin Li     {
1507*3e777be0SXin Li         return Fail("%s: Unsupported input tensor type: %d", __func__, inputType);
1508*3e777be0SXin Li     }
1509*3e777be0SXin Li 
1510*3e777be0SXin Li     // Read axis
1511*3e777be0SXin Li     if (!GetInputInt32<HalPolicy>(operation, 2, descriptor.m_Axis, model, data))
1512*3e777be0SXin Li     {
1513*3e777be0SXin Li         return Fail("%s: Failed to read input 2", __func__);
1514*3e777be0SXin Li     }
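    // For reference: along the chosen axis this computes
    //   out_i = beta * x_i - log(sum_j(exp(beta * x_j)))
    // so beta scales the logits and m_Axis selects the reduction dimension.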
1515*3e777be0SXin Li 
1516*3e777be0SXin Li     bool isSupported = false;
1517*3e777be0SXin Li     armnn::BackendId setBackend;
1518*3e777be0SXin Li     auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported)
1519*3e777be0SXin Li     {
1520*3e777be0SXin Li         FORWARD_LAYER_SUPPORT_FUNC(__func__,
1521*3e777be0SXin Li                                    IsLogSoftmaxSupported,
1522*3e777be0SXin Li                                    data.m_Backends,
1523*3e777be0SXin Li                                    isSupported,
1524*3e777be0SXin Li                                    setBackend,
1525*3e777be0SXin Li                                    input.GetTensorInfo(),
1526*3e777be0SXin Li                                    outputInfo,
1527*3e777be0SXin Li                                    descriptor);
1528*3e777be0SXin Li     };
1529*3e777be0SXin Li 
1530*3e777be0SXin Li     if(IsDynamicTensor(outputInfo))
1531*3e777be0SXin Li     {
1532*3e777be0SXin Li         isSupported = AreDynamicTensorsSupported();
1533*3e777be0SXin Li     }
1534*3e777be0SXin Li     else
1535*3e777be0SXin Li     {
1536*3e777be0SXin Li         validateFunc(outputInfo, isSupported);
1537*3e777be0SXin Li     }
1538*3e777be0SXin Li 
1539*3e777be0SXin Li     if (!isSupported)
1540*3e777be0SXin Li     {
1541*3e777be0SXin Li         return false;
1542*3e777be0SXin Li     }
1543*3e777be0SXin Li 
1544*3e777be0SXin Li     IConnectableLayer* layer = data.m_Network->AddLogSoftmaxLayer(descriptor);
1545*3e777be0SXin Li     if (!layer)
1546*3e777be0SXin Li     {
1547*3e777be0SXin Li         return Fail("%s: Could not add the LogSoftmaxLayer", __func__);
1548*3e777be0SXin Li     }
1549*3e777be0SXin Li     layer->SetBackendId(setBackend);
1550*3e777be0SXin Li     input.Connect(layer->GetInputSlot(0));
1551*3e777be0SXin Li 
1552*3e777be0SXin Li     return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, model, data, nullptr, validateFunc);
1553*3e777be0SXin Li }
1554*3e777be0SXin Li 
1555*3e777be0SXin Li template<typename HalPolicy,
1556*3e777be0SXin Li          typename HalOperation = typename HalPolicy::Operation,
1557*3e777be0SXin Li          typename HalModel     = typename HalPolicy::Model>
1558*3e777be0SXin Li bool ConvertPadV2(const HalOperation& operation, const HalModel& model, ConversionData& data)
1559*3e777be0SXin Li {
1560*3e777be0SXin Li     using HalOperand     = typename HalPolicy::Operand;
1561*3e777be0SXin Li     using HalOperandType = typename HalPolicy::OperandType;
1562*3e777be0SXin Li 
1563*3e777be0SXin Li     ALOGV("HalPolicy::ConvertPadV2()");
1564*3e777be0SXin Li 
1565*3e777be0SXin Li     LayerInputHandle input = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);
1566*3e777be0SXin Li     if (!input.IsValid())
1567*3e777be0SXin Li     {
1568*3e777be0SXin Li         return Fail("%s: Could not read input 0", __func__);
1569*3e777be0SXin Li     }
1570*3e777be0SXin Li 
1571*3e777be0SXin Li     const HalOperand* output = GetOutputOperand<HalPolicy>(operation, 0, model);
1572*3e777be0SXin Li     if (!output)
1573*3e777be0SXin Li     {
1574*3e777be0SXin Li         return Fail("%s: Could not read output", __func__);
1575*3e777be0SXin Li     }
1576*3e777be0SXin Li 
1577*3e777be0SXin Li     const TensorInfo& inputInfo  = input.GetTensorInfo();
1578*3e777be0SXin Li     unsigned int rank = inputInfo.GetNumDimensions();
1579*3e777be0SXin Li 
1580*3e777be0SXin Li     PadDescriptor descriptor;
1581*3e777be0SXin Li     if (!ConvertPaddings<HalPolicy>(operation, model, data, rank, descriptor))
1582*3e777be0SXin Li     {
1583*3e777be0SXin Li         return Fail("%s: Could not convert paddings", __func__);
1584*3e777be0SXin Li     }
1585*3e777be0SXin Li 
1586*3e777be0SXin Li     const TensorInfo& outputInfo = GetTensorInfoForOperand(*output);
1587*3e777be0SXin Li 
1588*3e777be0SXin Li     // Determine type of padding value
1589*3e777be0SXin Li     HalOperandType operandType0;
1590*3e777be0SXin Li     HalOperandType operandType2;
1591*3e777be0SXin Li 
1592*3e777be0SXin Li     if (!GetOperandType<HalPolicy>(operation, 0, model, operandType0) ||
1593*3e777be0SXin Li         !GetOperandType<HalPolicy>(operation, 2, model, operandType2))
1594*3e777be0SXin Li     {
1595*3e777be0SXin Li         return Fail("%s: Operation has invalid inputs", __func__);
1596*3e777be0SXin Li     }
1597*3e777be0SXin Li 
1598*3e777be0SXin Li     // Read value to use for padding
1599*3e777be0SXin Li     if (operandType0 == HalOperandType::TENSOR_FLOAT16 && operandType2 == HalOperandType::FLOAT16)
1600*3e777be0SXin Li     {
1601*3e777be0SXin Li         Half f16PadValue;
1602*3e777be0SXin Li         if (!GetInputScalar<HalPolicy>(operation, 2, operandType2, f16PadValue, model, data))
1603*3e777be0SXin Li         {
1604*3e777be0SXin Li             return Fail("%s: Could not read input 2 (FLOAT16)", __func__);
1605*3e777be0SXin Li         }
1606*3e777be0SXin Li 
1607*3e777be0SXin Li         descriptor.m_PadValue = f16PadValue;
1608*3e777be0SXin Li     }
1609*3e777be0SXin Li     else if (operandType0 == HalOperandType::TENSOR_FLOAT32 && operandType2 == HalOperandType::FLOAT32)
1610*3e777be0SXin Li     {
1611*3e777be0SXin Li         if (!GetInputFloat32<HalPolicy>(operation, 2, descriptor.m_PadValue, model, data))
1612*3e777be0SXin Li         {
1613*3e777be0SXin Li             return Fail("%s: Could not read input 2 (FLOAT32)", __func__);
1614*3e777be0SXin Li         }
1615*3e777be0SXin Li     }
1616*3e777be0SXin Li     else if (isQuantizedOperand(operandType0) && operandType2 == HalOperandType::INT32)
1617*3e777be0SXin Li     {
1618*3e777be0SXin Li         int32_t intPadValue = 0;
1619*3e777be0SXin Li         if (!GetInputInt32<HalPolicy>(operation, 2, intPadValue, model, data))
1620*3e777be0SXin Li         {
1621*3e777be0SXin Li             return Fail("%s: Could not read input 2 (INT32)", __func__);
1622*3e777be0SXin Li         }
1623*3e777be0SXin Li         descriptor.m_PadValue = intPadValue;
1624*3e777be0SXin Li     }
1625*3e777be0SXin Li     else
1626*3e777be0SXin Li     {
1627*3e777be0SXin Li         return Fail("%s: Operation has invalid inputs: type mismatch", __func__);
1628*3e777be0SXin Li     }
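    // Note on the cases above: the pad value is expected in the same domain as the input
    // tensor, e.g. for a quantized input it arrives as an INT32 already expressed in
    // quantized units, which is why it is assigned to m_PadValue without rescaling.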
1629*3e777be0SXin Li 
1630*3e777be0SXin Li     bool isSupported = false;
1631*3e777be0SXin Li     armnn::BackendId setBackend;
1632*3e777be0SXin Li     auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported)
1633*3e777be0SXin Li     {
1634*3e777be0SXin Li         FORWARD_LAYER_SUPPORT_FUNC(__func__,
1635*3e777be0SXin Li                                    IsPadSupported,
1636*3e777be0SXin Li                                    data.m_Backends,
1637*3e777be0SXin Li                                    isSupported,
1638*3e777be0SXin Li                                    setBackend,
1639*3e777be0SXin Li                                    inputInfo,
1640*3e777be0SXin Li                                    outputInfo,
1641*3e777be0SXin Li                                    descriptor);
1642*3e777be0SXin Li     };
1643*3e777be0SXin Li 
1644*3e777be0SXin Li     if(IsDynamicTensor(outputInfo))
1645*3e777be0SXin Li     {
1646*3e777be0SXin Li         isSupported = AreDynamicTensorsSupported();
1647*3e777be0SXin Li     }
1648*3e777be0SXin Li     else
1649*3e777be0SXin Li     {
1650*3e777be0SXin Li         validateFunc(outputInfo, isSupported);
1651*3e777be0SXin Li     }
1652*3e777be0SXin Li 
1653*3e777be0SXin Li     if (!isSupported)
1654*3e777be0SXin Li     {
1655*3e777be0SXin Li         return false;
1656*3e777be0SXin Li     }
1657*3e777be0SXin Li 
1658*3e777be0SXin Li     IConnectableLayer* const layer = data.m_Network->AddPadLayer(descriptor);
1659*3e777be0SXin Li     if (!layer)
1660*3e777be0SXin Li     {
1661*3e777be0SXin Li         return Fail("%s: Could not add the PadLayer", __func__);
1662*3e777be0SXin Li     }
1663*3e777be0SXin Li     layer->SetBackendId(setBackend);
1664*3e777be0SXin Li     input.Connect(layer->GetInputSlot(0));
1665*3e777be0SXin Li 
1666*3e777be0SXin Li     return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, model, data, nullptr, validateFunc);
1667*3e777be0SXin Li }
1668*3e777be0SXin Li 
1669*3e777be0SXin Li template<typename HalPolicy,
1670*3e777be0SXin Li          typename HalOperation = typename HalPolicy::Operation,
1671*3e777be0SXin Li          typename HalModel     = typename HalPolicy::Model>
1672*3e777be0SXin Li bool ConvertPrelu(const HalOperation& operation, const HalModel& model, ConversionData& data)
1673*3e777be0SXin Li {
1674*3e777be0SXin Li     using HalOperand = typename HalPolicy::Operand;
1675*3e777be0SXin Li 
1676*3e777be0SXin Li     ALOGV("HalPolicy::ConvertPrelu()");
1677*3e777be0SXin Li 
1678*3e777be0SXin Li     LayerInputHandle input = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);
1679*3e777be0SXin Li     LayerInputHandle alpha = ConvertToLayerInputHandle<HalPolicy>(operation, 1, model, data);
1680*3e777be0SXin Li 
1681*3e777be0SXin Li     if (!input.IsValid() || !alpha.IsValid())
1682*3e777be0SXin Li     {
1683*3e777be0SXin Li         return Fail("%s: Operation has invalid inputs", __func__);
1684*3e777be0SXin Li     }
1685*3e777be0SXin Li 
1686*3e777be0SXin Li     const HalOperand* output = GetOutputOperand<HalPolicy>(operation, 0, model);
1687*3e777be0SXin Li 
1688*3e777be0SXin Li     if (!output)
1689*3e777be0SXin Li     {
1690*3e777be0SXin Li         return Fail("%s: Could not read output", __func__);
1691*3e777be0SXin Li     }
1692*3e777be0SXin Li 
1693*3e777be0SXin Li     const TensorInfo& inputInfo  = input.GetTensorInfo();
1694*3e777be0SXin Li     const TensorInfo& alphaInfo  = alpha.GetTensorInfo();
1695*3e777be0SXin Li     const TensorInfo& outputInfo = GetTensorInfoForOperand(*output);
1696*3e777be0SXin Li 
1697*3e777be0SXin Li     bool isSupported = false;
1698*3e777be0SXin Li     armnn::BackendId setBackend;
1699*3e777be0SXin Li     auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported)
1700*3e777be0SXin Li     {
1701*3e777be0SXin Li         FORWARD_LAYER_SUPPORT_FUNC(__func__,
1702*3e777be0SXin Li                                    IsPreluSupported,
1703*3e777be0SXin Li                                    data.m_Backends,
1704*3e777be0SXin Li                                    isSupported,
1705*3e777be0SXin Li                                    setBackend,
1706*3e777be0SXin Li                                    inputInfo,
1707*3e777be0SXin Li                                    alphaInfo,
1708*3e777be0SXin Li                                    outputInfo);
1709*3e777be0SXin Li     };
1710*3e777be0SXin Li 
1711*3e777be0SXin Li     if(IsDynamicTensor(outputInfo))
1712*3e777be0SXin Li     {
1713*3e777be0SXin Li         isSupported = AreDynamicTensorsSupported();
1714*3e777be0SXin Li     }
1715*3e777be0SXin Li     else
1716*3e777be0SXin Li     {
1717*3e777be0SXin Li         validateFunc(outputInfo, isSupported);
1718*3e777be0SXin Li     }
1719*3e777be0SXin Li 
1720*3e777be0SXin Li     if (!isSupported)
1721*3e777be0SXin Li     {
1722*3e777be0SXin Li         return false;
1723*3e777be0SXin Li     }
1724*3e777be0SXin Li 
1725*3e777be0SXin Li     IConnectableLayer* const layer = data.m_Network->AddPreluLayer();
1726*3e777be0SXin Li     if (!layer)
1727*3e777be0SXin Li     {
1728*3e777be0SXin Li         return Fail("%s: Could not add the PreluLayer", __func__);
1729*3e777be0SXin Li     }
1730*3e777be0SXin Li     layer->SetBackendId(setBackend);
1731*3e777be0SXin Li 
1732*3e777be0SXin Li     bool isReshapeSupported = BroadcastTensor(input, alpha, layer, data);
1733*3e777be0SXin Li     if (!isReshapeSupported)
1734*3e777be0SXin Li     {
1735*3e777be0SXin Li         return false;
1736*3e777be0SXin Li     }
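    // For reference: PReLU computes out = x for x >= 0 and out = alpha * x otherwise, with
    // alpha broadcast against the input, e.g. an input of [1, 224, 224, 32] with an alpha of
    // [32] (shapes are illustrative); BroadcastTensor() above inserts any reshape this needs.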
1737*3e777be0SXin Li 
1738*3e777be0SXin Li     return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, model, data, nullptr, validateFunc);
1739*3e777be0SXin Li }
1740*3e777be0SXin Li 
1741*3e777be0SXin Li template<typename HalPolicy,
1742*3e777be0SXin Li          typename HalOperation = typename HalPolicy::Operation,
1743*3e777be0SXin Li          typename HalModel     = typename HalPolicy::Model>
1744*3e777be0SXin Li bool ConvertQuantize(const HalOperation& operation, const HalModel& model, ConversionData& data)
1745*3e777be0SXin Li {
1746*3e777be0SXin Li     using HalOperand = typename HalPolicy::Operand;
1747*3e777be0SXin Li 
1748*3e777be0SXin Li     ALOGV("HalPolicy::ConvertQuantize()");
1749*3e777be0SXin Li 
1750*3e777be0SXin Li     LayerInputHandle input = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);
1751*3e777be0SXin Li     if (!input.IsValid())
1752*3e777be0SXin Li     {
1753*3e777be0SXin Li         return Fail("%s: Operation has invalid input", __func__);
1754*3e777be0SXin Li     }
1755*3e777be0SXin Li 
1756*3e777be0SXin Li     const HalOperand* const outputOperand = GetOutputOperand<HalPolicy>(operation, 0, model);
1757*3e777be0SXin Li     if (!outputOperand)
1758*3e777be0SXin Li     {
1759*3e777be0SXin Li         return Fail("%s: Operation has invalid outputs", __func__);
1760*3e777be0SXin Li     }
1761*3e777be0SXin Li 
1762*3e777be0SXin Li     const TensorInfo& outputInfo = GetTensorInfoForOperand(*outputOperand);
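    // For reference: quantization maps a float x to round(x / scale) + zeroPoint, clamped
    // to the numeric range of the output type; scale and zeroPoint come from outputInfo.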
1763*3e777be0SXin Li 
1764*3e777be0SXin Li     bool isSupported = false;
1765*3e777be0SXin Li     armnn::BackendId setBackend;
1766*3e777be0SXin Li     auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported)
1767*3e777be0SXin Li     {
1768*3e777be0SXin Li         FORWARD_LAYER_SUPPORT_FUNC(__func__,
1769*3e777be0SXin Li                                    IsQuantizeSupported,
1770*3e777be0SXin Li                                    data.m_Backends,
1771*3e777be0SXin Li                                    isSupported,
1772*3e777be0SXin Li                                    setBackend,
1773*3e777be0SXin Li                                    input.GetTensorInfo(),
1774*3e777be0SXin Li                                    outputInfo);
1775*3e777be0SXin Li     };
1776*3e777be0SXin Li 
1777*3e777be0SXin Li     if(IsDynamicTensor(outputInfo))
1778*3e777be0SXin Li     {
1779*3e777be0SXin Li         isSupported = AreDynamicTensorsSupported();
1780*3e777be0SXin Li     }
1781*3e777be0SXin Li     else
1782*3e777be0SXin Li     {
1783*3e777be0SXin Li         validateFunc(outputInfo, isSupported);
1784*3e777be0SXin Li     }
1785*3e777be0SXin Li 
1786*3e777be0SXin Li     if (!isSupported)
1787*3e777be0SXin Li     {
1788*3e777be0SXin Li         return false;
1789*3e777be0SXin Li     }
1790*3e777be0SXin Li 
1791*3e777be0SXin Li     IConnectableLayer* const layer = data.m_Network->AddQuantizeLayer();
1792*3e777be0SXin Li     if (!layer)
1793*3e777be0SXin Li     {
1794*3e777be0SXin Li         return Fail("%s: Could not add the QuantizeLayer", __func__);
1795*3e777be0SXin Li     }
1796*3e777be0SXin Li     layer->SetBackendId(setBackend);
1797*3e777be0SXin Li     input.Connect(layer->GetInputSlot(0));
1798*3e777be0SXin Li 
1799*3e777be0SXin Li     return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, model, data, nullptr, validateFunc);
1800*3e777be0SXin Li }
1801*3e777be0SXin Li 
1802*3e777be0SXin Li template<typename HalPolicy,
1803*3e777be0SXin Li          typename HalOperation = typename HalPolicy::Operation,
1804*3e777be0SXin Li          typename HalModel     = typename HalPolicy::Model>
1805*3e777be0SXin Li bool ConvertQuantized16BitLstm(const HalOperation& operation, const HalModel& model, ConversionData& data)
1806*3e777be0SXin Li {
1807*3e777be0SXin Li     using HalOperand = typename HalPolicy::Operand;
1808*3e777be0SXin Li 
1809*3e777be0SXin Li     ALOGV("HalPolicy::ConvertQuantized16BitLstm()");
1810*3e777be0SXin Li 
1811*3e777be0SXin Li     // Inputs:
1812*3e777be0SXin Li     // 0: The input: A 2-D tensor of type ANEURALNETWORKS_TENSOR_QUANT8_ASYMM and shape [numBatches, inputSize]
1813*3e777be0SXin Li     //    specifying the input to the LSTM cell. Tensor is quantized with a fixed quantization range of -1, 127/128.
1814*3e777be0SXin Li     LayerInputHandle input = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);
1815*3e777be0SXin Li     if (!input.IsValid())
1816*3e777be0SXin Li     {
1817*3e777be0SXin Li         return Fail("%s: Could not read input 0: input", __func__);
1818*3e777be0SXin Li     }
1819*3e777be0SXin Li 
1820*3e777be0SXin Li     //13: The previous cell state: A 2-D tensor of type ANEURALNETWORKS_TENSOR_QUANT16_SYMM and shape
1821*3e777be0SXin Li     //    [numBatches, outputSize] specifying the cell state from the previous time step of the LSTM cell.
1822*3e777be0SXin Li     //    It is quantized using a quantization range of -2^4, 2^4 * 32767/32768.
1823*3e777be0SXin Li     LayerInputHandle previousCellStateIn = ConvertToLayerInputHandle<HalPolicy>(operation, 13, model, data);
1824*3e777be0SXin Li     if (!previousCellStateIn.IsValid())
1825*3e777be0SXin Li     {
1826*3e777be0SXin Li         return Fail("%s: Could not read input 13: previousCellStateIn", __func__);
1827*3e777be0SXin Li     }
1828*3e777be0SXin Li 
1829*3e777be0SXin Li     // 14: The previous output state: A 2-D tensor of type ANEURALNETWORKS_TENSOR_QUANT8_ASYMM and shape
1830*3e777be0SXin Li     //     [numBatches, outputSize] specifying the output of the LSTM cell from the previous time step. Tensor
1831*3e777be0SXin Li     //     is quantized with a fixed quantization range of -1, 127/128.
1832*3e777be0SXin Li     LayerInputHandle previousOutputIn = ConvertToLayerInputHandle<HalPolicy>(operation, 14, model, data);
1833*3e777be0SXin Li     if (!previousOutputIn.IsValid())
1834*3e777be0SXin Li     {
1835*3e777be0SXin Li         return Fail("%s: Could not read input 14: previousOutputIn", __func__);
1836*3e777be0SXin Li     }
1837*3e777be0SXin Li 
1838*3e777be0SXin Li     // Get the input tensors:
1839*3e777be0SXin Li     // 1: The input-to-input weights. A 2-D tensor of type ANEURALNETWORKS_TENSOR_QUANT8_ASYMM and shape
1840*3e777be0SXin Li     //    [outputSize, inputSize] specifying input-to-input part of weights for fully-connected layer inside the
1841*3e777be0SXin Li     //    LSTM cell. Quantization zero point and scale must be the same across all the weights.
1842*3e777be0SXin Li     const ConstTensorPin inputToInputWeightsPin =
1843*3e777be0SXin Li         ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 1, model, data);
1844*3e777be0SXin Li 
1845*3e777be0SXin Li     // 2: The input-to-forget weights. A 2-D tensor of type ANEURALNETWORKS_TENSOR_QUANT8_ASYMM and shape
1846*3e777be0SXin Li     //    [outputSize, inputSize] specifying input-to-forget part of weights for fully-connected layer inside the
1847*3e777be0SXin Li     //    LSTM cell. Quantization zero point and scale must be the same across all the weights.
1848*3e777be0SXin Li     const ConstTensorPin inputToForgetWeightsPin =
1849*3e777be0SXin Li         ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 2, model, data);
1850*3e777be0SXin Li 
1851*3e777be0SXin Li     // 3: The input-to-cell weights. A 2-D tensor of type ANEURALNETWORKS_TENSOR_QUANT8_ASYMM and shape
1852*3e777be0SXin Li     //    [outputSize, inputSize] specifying input-to-cell part of weights for fully-connected layer inside the
1853*3e777be0SXin Li     //    LSTM cell. Quantization zero point and scale must be the same across all the weights.
1854*3e777be0SXin Li     const ConstTensorPin inputToCellWeightsPin =
1855*3e777be0SXin Li         ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 3, model, data);
1856*3e777be0SXin Li 
1857*3e777be0SXin Li     // 4: The input-to-output weights. A 2-D tensor of type ANEURALNETWORKS_TENSOR_QUANT8_ASYMM and shape
1858*3e777be0SXin Li     //    [outputSize, inputSize] specifying input-to-output part of weights for fully-connected layer inside the
1859*3e777be0SXin Li     //    LSTM cell. Quantization zero point and scale must be the same across all the weights.
1860*3e777be0SXin Li     const ConstTensorPin inputToOutputWeightsPin =
1861*3e777be0SXin Li         ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 4, model, data);
1862*3e777be0SXin Li 
1863*3e777be0SXin Li     // 5: The recurrent-to-input weights. A 2-D tensor of type ANEURALNETWORKS_TENSOR_QUANT8_ASYMM and shape
1864*3e777be0SXin Li     //    [outputSize, outputSize] specifying recurrent-to-input part of weights for fully-connected layer inside
1865*3e777be0SXin Li     //    the LSTM cell. Quantization zero point and scale must be the same across all the weights.
1866*3e777be0SXin Li     const ConstTensorPin recurrentToInputWeightsPin =
1867*3e777be0SXin Li         ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 5, model, data);
1868*3e777be0SXin Li 
1869*3e777be0SXin Li     // 6: The recurrent-to-forget weights. A 2-D tensor of type ANEURALNETWORKS_TENSOR_QUANT8_ASYMM and shape
1870*3e777be0SXin Li     //    [outputSize, outputSize] specifying recurrent-to-forget part of weights for fully-connected layer inside
1871*3e777be0SXin Li     //    the LSTM cell. Quantization zero point and scale must be the same across all the weights.
1872*3e777be0SXin Li     const ConstTensorPin recurrentToForgetWeightsPin =
1873*3e777be0SXin Li         ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 6, model, data);
1874*3e777be0SXin Li 
1875*3e777be0SXin Li     // 7: The recurrent-to-cell weights. A 2-D tensor of type ANEURALNETWORKS_TENSOR_QUANT8_ASYMM and shape
1876*3e777be0SXin Li     //    [outputSize, outputSize] specifying recurrent-to-cell part of weights for fully-connected layer inside
1877*3e777be0SXin Li     //    the LSTM cell. Quantization zero point and scale must be the same across all the weights.
1878*3e777be0SXin Li     const ConstTensorPin recurrentToCellWeightsPin =
1879*3e777be0SXin Li         ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 7, model, data);
1880*3e777be0SXin Li 
1881*3e777be0SXin Li     // 8: The recurrent-to-output weights. A 2-D tensor of type ANEURALNETWORKS_TENSOR_QUANT8_ASYMM and shape
1882*3e777be0SXin Li     //    [outputSize, outputSize] specifying recurrent-to-output part of weights for fully-connected layer inside
1883*3e777be0SXin Li     //    the LSTM cell. Quantization zero point and scale must be the same across all the weights.
1884*3e777be0SXin Li     const ConstTensorPin recurrentToOutputWeightsPin =
1885*3e777be0SXin Li         ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 8, model, data);
1886*3e777be0SXin Li 
1887*3e777be0SXin Li     // 9: The input gate bias. A 1-D tensor of type ANEURALNETWORKS_TENSOR_INT32 and shape [outputSize] specifying the
1888*3e777be0SXin Li     //    bias for the fully-connected layer inside the LSTM cell. Bias is quantized with scale being a product
1889*3e777be0SXin Li     //    of input and weights scales and zeroPoint equal to 0.
1890*3e777be0SXin Li     const ConstTensorPin inputGateBiasPin =
1891*3e777be0SXin Li         ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 9, model, data);
1892*3e777be0SXin Li 
1893*3e777be0SXin Li     // 10: The forget gate bias. A 1-D tensor of type ANEURALNETWORKS_TENSOR_INT32 and shape [outputSize] specifying
1894*3e777be0SXin Li     //     the bias for the fully-connected layer inside the LSTM cell. Bias is quantized with scale being a product
1895*3e777be0SXin Li     //     of input and weights scales and zeroPoint equal to 0.
1896*3e777be0SXin Li     const ConstTensorPin forgetGateBiasPin =
1897*3e777be0SXin Li         ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 10, model, data);
1898*3e777be0SXin Li 
1899*3e777be0SXin Li     // 11: The cell bias. A 1-D tensor of type ANEURALNETWORKS_TENSOR_INT32 and shape [outputSize] specifying the bias
1900*3e777be0SXin Li     //    for the fully-connected layer inside the LSTM cell. Bias is quantized with scale being a product of input
1901*3e777be0SXin Li     //    and weights scales and zeroPoint equal to 0.
1902*3e777be0SXin Li     const ConstTensorPin cellBiasPin =
1903*3e777be0SXin Li         ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 11, model, data);
1904*3e777be0SXin Li 
1905*3e777be0SXin Li     // 12: The output gate bias. A 1-D tensor of type ANEURALNETWORKS_TENSOR_INT32 and shape [outputSize] specifying
1906*3e777be0SXin Li     //    the bias for the fully-connected layer inside the LSTM cell. Bias is quantized with scale being a product
1907*3e777be0SXin Li     //    of input and weights scales and zeroPoint equal to 0.
1908*3e777be0SXin Li     const ConstTensorPin outputGateBiasPin =
1909*3e777be0SXin Li         ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 12, model, data);

    if (!inputToInputWeightsPin.IsValid() ||
        !inputToForgetWeightsPin.IsValid() ||
        !inputToCellWeightsPin.IsValid() ||
        !inputToOutputWeightsPin.IsValid() ||
        !recurrentToInputWeightsPin.IsValid() ||
        !recurrentToForgetWeightsPin.IsValid() ||
        !recurrentToCellWeightsPin.IsValid() ||
        !recurrentToOutputWeightsPin.IsValid() ||
        !inputGateBiasPin.IsValid() ||
        !forgetGateBiasPin.IsValid() ||
        !cellBiasPin.IsValid() ||
        !outputGateBiasPin.IsValid())
    {
        return Fail("%s: Operation has invalid tensor inputs", __func__);
    }

    // Outputs:
    // 0: The cell state: A 2-D tensor of type ANEURALNETWORKS_TENSOR_QUANT16_SYMM and shape [numBatches, outputSize]
    //    which contains a cell state from the current time step. Tensor is quantized using a quantization range
    //    of [-2^4, 2^4 * 32767/32768].
    const HalOperand* cellStateOut = GetOutputOperand<HalPolicy>(operation, 0, model);
    if (!cellStateOut)
    {
        return Fail("%s: Could not read output 0: cellStateOut", __func__);
    }

    // 1: The output: A 2-D tensor of type ANEURALNETWORKS_TENSOR_QUANT8_ASYMM and shape [numBatches, outputSize]
    //    which contains the output value. Tensor is quantized with a fixed quantization range of [-1, 127/128].
    const HalOperand* output = GetOutputOperand<HalPolicy>(operation, 1, model);
    if (!output)
    {
        return Fail("%s: Could not read output 1: output", __func__);
    }

    // Inputs
    const TensorInfo& inputInfo               = input.GetTensorInfo();
    const TensorInfo& previousCellStateInInfo = previousCellStateIn.GetTensorInfo();
    const TensorInfo& previousOutputInInfo    = previousOutputIn.GetTensorInfo();

    // Outputs
    const TensorInfo& cellStateOutInfo = GetTensorInfoForOperand(*cellStateOut);
    const TensorInfo& outputInfo       = GetTensorInfoForOperand(*output);

    // Dynamic tensors currently not supported
    if (IsDynamicTensor(cellStateOutInfo) || IsDynamicTensor(outputInfo))
    {
        return Fail("%s: Dynamic output tensors are not supported", __func__);
    }

    QuantizedLstmInputParams params;

    params.m_InputToInputWeights      = inputToInputWeightsPin.GetConstTensorPtr();
    params.m_InputToForgetWeights     = inputToForgetWeightsPin.GetConstTensorPtr();
    params.m_InputToCellWeights       = inputToCellWeightsPin.GetConstTensorPtr();
    params.m_InputToOutputWeights     = inputToOutputWeightsPin.GetConstTensorPtr();
    params.m_RecurrentToInputWeights  = recurrentToInputWeightsPin.GetConstTensorPtr();
    params.m_RecurrentToForgetWeights = recurrentToForgetWeightsPin.GetConstTensorPtr();
    params.m_RecurrentToCellWeights   = recurrentToCellWeightsPin.GetConstTensorPtr();
    params.m_RecurrentToOutputWeights = recurrentToOutputWeightsPin.GetConstTensorPtr();
    params.m_InputGateBias            = inputGateBiasPin.GetConstTensorPtr();
    params.m_ForgetGateBias           = forgetGateBiasPin.GetConstTensorPtr();
    params.m_CellBias                 = cellBiasPin.GetConstTensorPtr();
    params.m_OutputGateBias           = outputGateBiasPin.GetConstTensorPtr();

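    // paramsInfo mirrors params but holds TensorInfo pointers rather than tensor data: the backends only need
    // shape/type/quantization metadata to answer the IsQuantizedLstmSupported query below.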
    QuantizedLstmInputParamsInfo paramsInfo;
    paramsInfo.m_InputToInputWeights      = &(params.m_InputToInputWeights->GetInfo());
    paramsInfo.m_InputToForgetWeights     = &(params.m_InputToForgetWeights->GetInfo());
    paramsInfo.m_InputToCellWeights       = &(params.m_InputToCellWeights->GetInfo());
    paramsInfo.m_InputToOutputWeights     = &(params.m_InputToOutputWeights->GetInfo());
    paramsInfo.m_RecurrentToInputWeights  = &(params.m_RecurrentToInputWeights->GetInfo());
    paramsInfo.m_RecurrentToForgetWeights = &(params.m_RecurrentToForgetWeights->GetInfo());
    paramsInfo.m_RecurrentToCellWeights   = &(params.m_RecurrentToCellWeights->GetInfo());
    paramsInfo.m_RecurrentToOutputWeights = &(params.m_RecurrentToOutputWeights->GetInfo());
    paramsInfo.m_InputGateBias            = &(params.m_InputGateBias->GetInfo());
    paramsInfo.m_ForgetGateBias           = &(params.m_ForgetGateBias->GetInfo());
    paramsInfo.m_CellBias                 = &(params.m_CellBias->GetInfo());
    paramsInfo.m_OutputGateBias           = &(params.m_OutputGateBias->GetInfo());

    bool isSupported = false;
    armnn::BackendId setBackend;
    auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported)
    {
        FORWARD_LAYER_SUPPORT_FUNC(__func__,
                                   IsQuantizedLstmSupported,
                                   data.m_Backends,
                                   isSupported,
                                   setBackend,
                                   inputInfo,
                                   previousCellStateInInfo,
                                   previousOutputInInfo,
                                   cellStateOutInfo,
                                   outputInfo,
                                   paramsInfo);
    };

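    // Note: dynamic cellStateOut/output tensors were already rejected above, so the else branch here (and the
    // isDynamic path at the end of this function) is effectively unreachable; it appears to be kept for symmetry
    // with the other converters in this file.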
    bool isDynamic = false;
    if (!IsDynamicTensor(cellStateOutInfo) &&
        !IsDynamicTensor(outputInfo))
    {
        validateFunc(outputInfo, isSupported);
    }
    else
    {
        isDynamic = true;
        isSupported = AreDynamicTensorsSupported();
    }

    if (!isSupported)
    {
        return false;
    }

    IConnectableLayer* const layer = data.m_Network->AddQuantizedLstmLayer(params, "QuantizedLstm");
    if (!layer)
    {
        return Fail("%s: Could not add the QuantizedLstmLayer", __func__);
    }
    layer->SetBackendId(setBackend);
    input.Connect(layer->GetInputSlot(0));
    previousCellStateIn.Connect(layer->GetInputSlot(1));
    previousOutputIn.Connect(layer->GetInputSlot(2));

    if (!isDynamic)
    {
        return (SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, 0, model, data) &&
                SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 1, *layer, 1, model, data));
    }
    else
    {
        return (SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, 0, model, data) &&
                SetupAndTrackLayerOutputSlot<HalPolicy>(
                    operation, 1, *layer, 1, model, data, nullptr, validateFunc, ActivationFn::kActivationNone, true));
    }
}

template<typename HalPolicy,
         typename HalOperation = typename HalPolicy::Operation,
         typename HalModel     = typename HalPolicy::Model>
bool ConvertReduce(const HalOperation& operation,
                   const HalModel& model,
                   ConversionData& data,
                   ReduceOperation reduceOperation)
{
    using HalOperand     = typename HalPolicy::Operand;
    using HalOperandType = typename HalPolicy::OperandType;

    armnn::ReduceDescriptor descriptor;
    descriptor.m_ReduceOperation = reduceOperation;

    LayerInputHandle input = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);
    if (!input.IsValid())
    {
        return Fail("%s: Operation has invalid inputs", __func__);
    }
    const armnn::TensorInfo& inputInfo = input.GetTensorInfo();

    const HalOperand* output = GetOutputOperand<HalPolicy>(operation, 0, model);
    if (!output)
    {
        return Fail("%s: Could not read output 0", __func__);
    }
    const armnn::TensorInfo& outputInfo = GetTensorInfoForOperand(*output);

    const HalOperand* axisOperand = GetInputOperand<HalPolicy>(operation, 1, model);
    if (!axisOperand)
    {
        return Fail("%s: Could not read input 1", __func__);
    }
    std::vector<int32_t> axis;
    if (!GetTensorInt32Values<HalPolicy>(*axisOperand, axis, model, data))
    {
        return Fail("%s: Input 1 has invalid values", __func__);
    }

    // Convert the axis to unsigned int and remove duplicates.
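    // (i + rank) % rank wraps negative axes, e.g. for a rank-4 input, axis -1 maps to 3; inserting into a
    // std::set both de-duplicates and sorts the axes.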
    unsigned int rank = inputInfo.GetNumDimensions();
    std::set<unsigned int> uniqueAxis;
    std::transform(axis.begin(), axis.end(),
                   std::inserter(uniqueAxis, uniqueAxis.begin()),
                   [rank](int i) -> unsigned int { return (i + rank) % rank; });
    descriptor.m_vAxis.assign(uniqueAxis.begin(), uniqueAxis.end());

    // Get the "keep dims" flag.
    if (!GetInputScalar<HalPolicy>(operation, 2, HalOperandType::BOOL, descriptor.m_KeepDims, model, data))
    {
        return Fail("%s: Could not read input 2", __func__);
    }

    bool isSupported = false;
    armnn::BackendId setBackend;
    auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported)
    {
        FORWARD_LAYER_SUPPORT_FUNC(__func__,
                                   IsReduceSupported,
                                   data.m_Backends,
                                   isSupported,
                                   setBackend,
                                   inputInfo,
                                   outputInfo,
                                   descriptor);
    };

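    // validateFunc either runs immediately (static output shapes) or is handed to SetupAndTrackLayerOutputSlot
    // below so that support can be re-checked once dynamic output shapes have been inferred.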
    if (!IsDynamicTensor(outputInfo))
    {
        validateFunc(outputInfo, isSupported);
    }
    else
    {
        isSupported = AreDynamicTensorsSupported();
    }

    if (!isSupported)
    {
        return false;
    }

    armnn::IConnectableLayer* const layer = data.m_Network->AddReduceLayer(descriptor);
    if (!layer)
    {
        return Fail("%s: Could not add the ReduceLayer", __func__);
    }
    layer->SetBackendId(setBackend);
    input.Connect(layer->GetInputSlot(0));

    return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, model, data, nullptr, validateFunc);
}

template<typename HalPolicy,
         typename HalOperation = typename HalPolicy::Operation,
         typename HalModel     = typename HalPolicy::Model>
bool ConvertResize(const HalOperation& operation,
                   const HalModel& model,
                   ConversionData& data,
                   ResizeMethod resizeMethod)
{
    using HalOperand     = typename HalPolicy::Operand;
    using HalOperandType = typename HalPolicy::OperandType;
    ALOGV("HalPolicy::ConvertResize()");
    ALOGV("resizeMethod = %s", GetResizeMethodAsCString(resizeMethod));

    LayerInputHandle input = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);
    if (!input.IsValid())
    {
        return Fail("%s: Could not read input 0", __func__);
    }

    const HalOperand* output = GetOutputOperand<HalPolicy>(operation, 0, model);
    if (!output)
    {
        return Fail("%s: Could not read output 0", __func__);
    }

    const TensorInfo& inputInfo  = input.GetTensorInfo();
    const TensorInfo& outputInfo = GetTensorInfoForOperand(*output);

    ResizeDescriptor descriptor;
    descriptor.m_Method     = resizeMethod;
    descriptor.m_DataLayout = OptionalDataLayout<HalPolicy>(operation, 3, model, data);

    HalOperandType operandType1;
    HalOperandType operandType2;

    if (!GetOperandType<HalPolicy>(operation, 1, model, operandType1) ||
        !GetOperandType<HalPolicy>(operation, 2, model, operandType2))
    {
        return Fail("%s: Operation has invalid inputs", __func__);
    }

    if (operandType1 != operandType2)
    {
        return Fail("%s: Operation has invalid inputs. Type of input 1 and 2 should be the same", __func__);
    }

    if (operandType1 == HalOperandType::INT32)
    {
        // Case 1: resizing by shape
        int32_t targetWidth  = 0;
        int32_t targetHeight = 0;

        if (!GetInputInt32<HalPolicy>(operation, 1, targetWidth, model, data) ||
            !GetInputInt32<HalPolicy>(operation, 2, targetHeight, model, data))
        {
            return Fail("%s: Operation has invalid inputs for resizing by shape", __func__);
        }

        if (targetWidth < 0 || targetHeight < 0)
        {
            return Fail("%s: Operation has invalid inputs for resizing by shape. "
                        "Target width/height cannot be < 0", __func__);
        }

        descriptor.m_TargetWidth = static_cast<uint32_t>(targetWidth);
        descriptor.m_TargetHeight = static_cast<uint32_t>(targetHeight);
    }
    else if (operandType1 == HalOperandType::FLOAT32)
    {
        // Case 2: resizing by scale
        float widthScale  = 1.0f;
        float heightScale = 1.0f;

        if (!GetInputFloat32<HalPolicy>(operation, 1, widthScale, model, data) ||
            !GetInputFloat32<HalPolicy>(operation, 2, heightScale, model, data))
        {
            return Fail("%s: Operation has invalid inputs for resizing by scale", __func__);
        }

        const TensorShape& inputShape = inputInfo.GetShape();
        armnnUtils::DataLayoutIndexed dataLayoutIndexed(descriptor.m_DataLayout);

        float width  = inputShape[dataLayoutIndexed.GetWidthIndex()];
        float height = inputShape[dataLayoutIndexed.GetHeightIndex()];

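        // The target extent is floor(inputExtent * scale), e.g. width 10 with scale 0.5f gives target width 5
        // (assumed to match the NNAPI definition of resizing by scale).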
        descriptor.m_TargetWidth  = std::floor(width  * widthScale);
        descriptor.m_TargetHeight = std::floor(height * heightScale);
    }
    else if (operandType1 == HalOperandType::FLOAT16)
    {
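        // Case 3: resizing by scale, where the scales are FLOAT16 values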
        Half widthScale;
        Half heightScale;

        if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::FLOAT16, widthScale, model, data) ||
            !GetInputScalar<HalPolicy>(operation, 2, HalOperandType::FLOAT16, heightScale, model, data))
        {
            return Fail("%s: Operation has invalid inputs for resizing by scale", __func__);
        }

        const TensorShape& inputShape = inputInfo.GetShape();
        armnnUtils::DataLayoutIndexed dataLayoutIndexed(descriptor.m_DataLayout);

        Half width  = static_cast<Half>(inputShape[dataLayoutIndexed.GetWidthIndex()]);
        Half height = static_cast<Half>(inputShape[dataLayoutIndexed.GetHeightIndex()]);

        descriptor.m_TargetWidth  = std::floor(width  * widthScale);
        descriptor.m_TargetHeight = std::floor(height * heightScale);
    }
    else
    {
        return Fail("%s: Operand has invalid data type for resizing by scale", __func__);
    }

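    // Inputs 4 (align corners) and 5 (half pixel centers) are optional; GetOptionalBool is assumed to return
    // false when the operand is absent, leaving both features disabled by default.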
    descriptor.m_AlignCorners     = GetOptionalBool<HalPolicy>(operation, 4, model, data);
    descriptor.m_HalfPixelCenters = GetOptionalBool<HalPolicy>(operation, 5, model, data);

    bool isSupported = false;
    armnn::BackendId setBackend;
    auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported)
    {
        FORWARD_LAYER_SUPPORT_FUNC(__func__,
                                   IsResizeSupported,
                                   data.m_Backends,
                                   isSupported,
                                   setBackend,
                                   inputInfo,
                                   outputInfo,
                                   descriptor);
    };

    if (IsDynamicTensor(outputInfo))
    {
        isSupported = AreDynamicTensorsSupported();
    }
    else
    {
        validateFunc(outputInfo, isSupported);
    }

    if (!isSupported)
    {
        return false;
    }

    IConnectableLayer* layer = data.m_Network->AddResizeLayer(descriptor);
    if (!layer)
    {
        return Fail("%s: Could not add the ResizeLayer", __func__);
    }
    layer->SetBackendId(setBackend);
    input.Connect(layer->GetInputSlot(0));

    return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, model, data, nullptr, validateFunc);
}

template<typename HalPolicy,
         typename HalOperation = typename HalPolicy::Operation,
         typename HalModel     = typename HalPolicy::Model>
bool ConvertSpaceToDepth(const HalOperation& operation, const HalModel& model, ConversionData& data)
{
    using HalOperand     = typename HalPolicy::Operand;
    using HalOperandType = typename HalPolicy::OperandType;

    ALOGV("HalPolicy::ConvertSpaceToDepth()");

    LayerInputHandle input = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);
    if (!input.IsValid())
    {
        return Fail("%s: Operation has invalid inputs", __func__);
    }

    const TensorInfo& inputInfo = input.GetTensorInfo();
    unsigned int rank = inputInfo.GetNumDimensions();
    if (rank != 4)
    {
        return Fail("%s: Only inputs with rank 4 are supported", __func__);
    }

    const HalOperand* output = GetOutputOperand<HalPolicy>(operation, 0, model);
    if (!output)
    {
        return Fail("%s: Could not read output 0", __func__);
    }

    const TensorInfo& outputInfo = GetTensorInfoForOperand(*output);

    SpaceToDepthDescriptor desc;

    if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::INT32, desc.m_BlockSize, model, data))
    {
        return Fail("%s: Could not read input 1: block size", __func__);
    }

    if (desc.m_BlockSize <= 1)
    {
        return Fail("%s: Block size must be greater than 1 in all dimensions", __func__);
    }
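
    // SPACE_TO_DEPTH moves spatial blocks into channels: e.g. an NHWC input of shape [1, 4, 4, 1] with block
    // size 2 produces an output of shape [1, 2, 2, 4].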

    desc.m_DataLayout = OptionalDataLayout<HalPolicy>(operation, 2, model, data);

    bool isSupported = false;
    armnn::BackendId setBackend;
    auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported)
    {
        FORWARD_LAYER_SUPPORT_FUNC(__func__,
                                   IsSpaceToDepthSupported,
                                   data.m_Backends,
                                   isSupported,
                                   setBackend,
                                   inputInfo,
                                   outputInfo,
                                   desc);
    };

    if (IsDynamicTensor(outputInfo))
    {
        isSupported = AreDynamicTensorsSupported();
    }
    else
    {
        validateFunc(outputInfo, isSupported);
    }

    if (!isSupported)
    {
        return false;
    }

    IConnectableLayer* const layer = data.m_Network->AddSpaceToDepthLayer(desc);
    if (!layer)
    {
        return Fail("%s: Could not add the SpaceToDepthLayer", __func__);
    }
    layer->SetBackendId(setBackend);
    input.Connect(layer->GetInputSlot(0));

    return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, model, data, nullptr, validateFunc);
}

template<typename HalPolicy,
         typename HalOperation = typename HalPolicy::Operation,
         typename HalModel     = typename HalPolicy::Model>
bool ConvertSoftmax(const HalOperation& operation, const HalModel& model, ConversionData& data)
{
    using HalOperand     = typename HalPolicy::Operand;
    using HalOperandType = typename HalPolicy::OperandType;

    ALOGV("HalPolicy::ConvertSoftmax()");

    LayerInputHandle input = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);
    if (!input.IsValid())
    {
        return Fail("%s: Operation has invalid inputs", __func__);
    }

    const HalOperand* outputOperand = GetOutputOperand<HalPolicy>(operation, 0, model);
    if (!outputOperand)
    {
        return Fail("%s: Operation has no outputs", __func__);
    }

    const TensorInfo& outputInfo = GetTensorInfoForOperand(*outputOperand);

    SoftmaxDescriptor desc;
    HalOperandType outputType = outputOperand->type;

    // Read beta value
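    // Beta scales the logits before normalisation: softmax(x)_i = exp(beta * x_i) / sum_j exp(beta * x_j).
    // FLOAT16 models supply beta as a FLOAT16 scalar, so it is read as a Half and widened to float.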
    if (outputType == HalOperandType::TENSOR_FLOAT16)
    {
        Half value;

        if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::FLOAT16, value, model, data))
        {
            return Fail("%s: Operation has invalid inputs %d", __func__, outputType);
        }

        desc.m_Beta = static_cast<float>(value);
    }
    else
    {
        if (!GetInputFloat32<HalPolicy>(operation, 1, desc.m_Beta, model, data))
        {
            return Fail("%s: Operation has invalid inputs %d", __func__, outputType);
        }
    }

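    // Input 2 (the axis) is optional; when it is absent, desc.m_Axis keeps the SoftmaxDescriptor default,
    // which is assumed to select the innermost dimension.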
    if (operation.inputs.size() > 2 && !GetInputScalar<HalPolicy>(operation,
                                                                  2,
                                                                  HalOperandType::INT32,
                                                                  desc.m_Axis,
                                                                  model,
                                                                  data))
    {
        return Fail("%s: Operation has invalid inputs", __func__);
    }

    bool isSupported = false;
    armnn::BackendId setBackend;
    auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported)
    {
        FORWARD_LAYER_SUPPORT_FUNC(__func__,
                                   IsSoftmaxSupported,
                                   data.m_Backends,
                                   isSupported,
                                   setBackend,
                                   input.GetTensorInfo(),
                                   outputInfo,
                                   desc);
    };

    if (IsDynamicTensor(outputInfo))
    {
        isSupported = AreDynamicTensorsSupported();
    }
    else
    {
        validateFunc(outputInfo, isSupported);
    }

    if (!isSupported)
    {
        return false;
    }

    IConnectableLayer* layer = data.m_Network->AddSoftmaxLayer(desc);
    if (!layer)
    {
        return Fail("%s: Could not add the SoftmaxLayer", __func__);
    }
    layer->SetBackendId(setBackend);
    input.Connect(layer->GetInputSlot(0));

    return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, model, data, nullptr, validateFunc);
}

template<typename HalPolicy,
         typename HalOperation = typename HalPolicy::Operation,
         typename HalModel     = typename HalPolicy::Model>
bool ConvertLstm(const HalOperation& operation, const HalModel& model, ConversionData& data)
{
    using HalOperand     = typename HalPolicy::Operand;
    using HalOperandType = typename HalPolicy::OperandType;

    ALOGV("HalPolicy::ConvertLstm()");

    // Inputs:
    // 00: The input: A 2-D tensor of ANEURALNETWORKS_TENSOR_FLOAT32, of shape [batch_size, input_size], where
    //     “batch_size” corresponds to the batching dimension, and “input_size” is the size of the input.
    LayerInputHandle input = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);
    if (!input.IsValid())
    {
        return Fail("%s: Could not read input 0: input", __func__);
    }
    // 18: The output state: A 2-D tensor of ANEURALNETWORKS_TENSOR_FLOAT32, of shape [batch_size, output_size].
    LayerInputHandle outputStateIn = ConvertToLayerInputHandle<HalPolicy>(operation, 18, model, data);
    if (!outputStateIn.IsValid())
    {
        return Fail("%s: Could not read input 18: outputStateIn", __func__);
    }
    // 19: The cell state: A 2-D tensor of ANEURALNETWORKS_TENSOR_FLOAT32, of shape [batch_size, num_units].
    LayerInputHandle cellStateIn = ConvertToLayerInputHandle<HalPolicy>(operation, 19, model, data);
    if (!cellStateIn.IsValid())
    {
        return Fail("%s: Could not read input 19: cellStateIn", __func__);
    }

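    // Weight inputs may arrive either as float tensors or as quantized constants; DequantizeAndMakeConstTensorPin
    // is assumed to dequantize the latter to float before pinning them as constants, which is why it is used for
    // the weight inputs below while the plain bias inputs go through ConvertOperationInputToConstTensorPin.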
    // Get the mandatory input tensors:
    // 02: The input-to-forget weights: A 2-D tensor of ANEURALNETWORKS_TENSOR_FLOAT32, of shape
    //     [num_units, input_size].
    const ConstTensorPin inputToForgetWeightsPin =
        DequantizeAndMakeConstTensorPin<HalPolicy>(operation, model, data, 2);
    // 03: The input-to-cell weights: A 2-D tensor of ANEURALNETWORKS_TENSOR_FLOAT32, of shape
    //     [num_units, input_size].
    const ConstTensorPin inputToCellWeightsPin =
        DequantizeAndMakeConstTensorPin<HalPolicy>(operation, model, data, 3);
    // 04: The input-to-output weights: A 2-D tensor of ANEURALNETWORKS_TENSOR_FLOAT32, of shape
    //     [num_units, input_size].
    const ConstTensorPin inputToOutputWeightsPin =
        DequantizeAndMakeConstTensorPin<HalPolicy>(operation, model, data, 4);
    // 06: The recurrent-to-forget weights: A 2-D tensor of ANEURALNETWORKS_TENSOR_FLOAT32, of shape
    //     [num_units, output_size].
    const ConstTensorPin recurrentToForgetWeightsPin =
        DequantizeAndMakeConstTensorPin<HalPolicy>(operation, model, data, 6);
    // 07: The recurrent-to-cell weights: A 2-D tensor of ANEURALNETWORKS_TENSOR_FLOAT32, of shape
    //     [num_units, output_size].
    const ConstTensorPin recurrentToCellWeightsPin =
        DequantizeAndMakeConstTensorPin<HalPolicy>(operation, model, data, 7);
    // 08: The recurrent-to-output weights: A 2-D tensor of ANEURALNETWORKS_TENSOR_FLOAT32, of shape
    //     [num_units, output_size].
    const ConstTensorPin recurrentToOutputWeightsPin =
        DequantizeAndMakeConstTensorPin<HalPolicy>(operation, model, data, 8);
    // 13: The forget gate bias: A 1-D tensor of ANEURALNETWORKS_TENSOR_FLOAT32, of shape [num_units].
    const ConstTensorPin forgetGateBiasPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 13, model, data);
    // 14: The cell bias: A 1-D tensor of ANEURALNETWORKS_TENSOR_FLOAT32, of shape [num_units].
    const ConstTensorPin cellBiasPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 14, model, data);
    // 15: The output gate bias: A 1-D tensor of ANEURALNETWORKS_TENSOR_FLOAT32, of shape [num_units].
    const ConstTensorPin outputGateBiasPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 15, model, data);

    if (!inputToForgetWeightsPin.IsValid() ||
        !inputToCellWeightsPin.IsValid() ||
        !inputToOutputWeightsPin.IsValid() ||
        !recurrentToForgetWeightsPin.IsValid() ||
        !recurrentToCellWeightsPin.IsValid() ||
        !recurrentToOutputWeightsPin.IsValid() ||
        !forgetGateBiasPin.IsValid() ||
        !cellBiasPin.IsValid() ||
        !outputGateBiasPin.IsValid())
    {
        return Fail("%s: Operation has invalid tensor inputs", __func__);
    }

    // Get the optional input tensors:
    // 01: The input-to-input weights: Optional. A 2-D tensor of ANEURALNETWORKS_TENSOR_FLOAT32, of shape
    //     [num_units, input_size], where “num_units” corresponds to the number of cell units.
    const ConstTensorPin inputToInputWeightsPin =
        DequantizeAndMakeConstTensorPin<HalPolicy>(operation, model, data, 1, true);
    // 05: The recurrent-to-input weights: Optional. A 2-D tensor of ANEURALNETWORKS_TENSOR_FLOAT32, of shape
    //     [num_units, output_size], where “output_size” corresponds to either the number of cell units (i.e.,
    //     “num_units”), or the second dimension of the “projection_weights”, if defined.
    const ConstTensorPin recurrentToInputWeightsPin =
        DequantizeAndMakeConstTensorPin<HalPolicy>(operation, model, data, 5, true);
    // 09: The cell-to-input weights: Optional. A 1-D tensor of ANEURALNETWORKS_TENSOR_FLOAT32, of shape [num_units].
    const ConstTensorPin cellToInputWeightsPin =
        DequantizeAndMakeConstTensorPin<HalPolicy>(operation, model, data, 9, true);
    // 10: The cell-to-forget weights: Optional. A 1-D tensor of ANEURALNETWORKS_TENSOR_FLOAT32, of shape [num_units].
    const ConstTensorPin cellToForgetWeightsPin =
        DequantizeAndMakeConstTensorPin<HalPolicy>(operation, model, data, 10, true);
    // 11: The cell-to-output weights: Optional. A 1-D tensor of ANEURALNETWORKS_TENSOR_FLOAT32, of shape [num_units].
    const ConstTensorPin cellToOutputWeightsPin =
        DequantizeAndMakeConstTensorPin<HalPolicy>(operation, model, data, 11, true);
    // 12: The input gate bias: Optional. A 1-D tensor of ANEURALNETWORKS_TENSOR_FLOAT32, of shape [num_units].
    const ConstTensorPin inputGateBiasPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         12,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    // 16: The projection weights: Optional. A 2-D tensor of ANEURALNETWORKS_TENSOR_FLOAT32, of shape
    //     [output_size, num_units].
    const ConstTensorPin projectionWeightsPin =
        DequantizeAndMakeConstTensorPin<HalPolicy>(operation, model, data, 16, true);
    // 17: The projection bias: Optional. A 1-D tensor of ANEURALNETWORKS_TENSOR_FLOAT32, of shape [output_size].
    const ConstTensorPin projectionBiasPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         17,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    if ((!inputToInputWeightsPin.IsValid() && !inputToInputWeightsPin.IsOptional()) ||
        (!recurrentToInputWeightsPin.IsValid() && !recurrentToInputWeightsPin.IsOptional()) ||
        (!cellToInputWeightsPin.IsValid() && !cellToInputWeightsPin.IsOptional()) ||
        (!cellToForgetWeightsPin.IsValid() && !cellToForgetWeightsPin.IsOptional()) ||
        (!cellToOutputWeightsPin.IsValid() && !cellToOutputWeightsPin.IsOptional()) ||
        (!inputGateBiasPin.IsValid() && !inputGateBiasPin.IsOptional()) ||
        (!projectionWeightsPin.IsValid() && !projectionWeightsPin.IsOptional()) ||
        (!projectionBiasPin.IsValid() && !projectionBiasPin.IsOptional()))
    {
        return Fail("%s: Operation has invalid tensor inputs", __func__);
    }

    // Get the mandatory input scalars (actually 1-D tensors of size 1):
    // 20: The activation function: A value indicating the activation function:
    //     0: None; 1: Relu; 3: Relu6; 4: Tanh; 6: Sigmoid.
    // 21: The clipping threshold: for the cell state, such that values are bound within [-cell_clip, cell_clip].
    //     If set to 0.0 then clipping is disabled.
    // 22: The clipping threshold: for the output from the projection layer, such that values are bound within
    //     [-proj_clip, proj_clip]. If set to 0.0 then clipping is disabled.
    ActivationFn activation = ActivationFn::kActivationNone;
    float cellClip;
    float projClip;
    if (!GetInputActivationFunctionFromTensor<HalPolicy>(operation, 20, activation, model, data) ||
        !GetInputScalar<HalPolicy>(operation, 21, HalOperandType::FLOAT32, cellClip, model, data) ||
        !GetInputScalar<HalPolicy>(operation, 22, HalOperandType::FLOAT32, projClip, model, data))
    {
        return Fail("%s: Operation has invalid scalar inputs", __func__);
    }

    // Get the normalization tensors
    // 23: The input layer normalization weights. A 1-D tensor of shape [num_units].
    //     Used to rescale normalized inputs to activation at input gate.
    const ConstTensorPin inputLayerNormWeightsPin =
        DequantizeAndMakeConstTensorPin<HalPolicy>(operation, model, data, 23, true);

    // 24: The forget layer normalization weights. A 1-D tensor of shape [num_units].
    //     Used to rescale normalized inputs to activation at forget gate.
    const ConstTensorPin forgetLayerNormWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         24,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    // 25: The cell layer normalization weights. A 1-D tensor of shape [num_units].
    //     Used to rescale normalized inputs to activation at cell gate.
    const ConstTensorPin cellLayerNormWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         25,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    // 26: The output layer normalization weights. A 1-D tensor of shape [num_units].
    //     Used to rescale normalized inputs to activation at output gate.
    const ConstTensorPin outputLayerNormWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         26,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    // Outputs:
    // 00: The scratch buffer: A 2-D tensor of ANEURALNETWORKS_TENSOR_FLOAT32, of shape [batch_size, num_units * 4]
    //     with CIFG, or [batch_size, num_units * 3] without CIFG.
    const HalOperand* scratchBuffer = GetOutputOperand<HalPolicy>(operation, 0, model);
    if (!scratchBuffer)
    {
        return Fail("%s: Could not read output 0: scratchBuffer", __func__);
    }
    // 01: The output state (out): A 2-D tensor of ANEURALNETWORKS_TENSOR_FLOAT32, of shape [batch_size, output_size].
    const HalOperand* outputStateOut = GetOutputOperand<HalPolicy>(operation, 1, model);
    if (!outputStateOut)
    {
        return Fail("%s: Could not read output 1: outputStateOut", __func__);
    }
    // 02: The cell state (out): A 2-D tensor of ANEURALNETWORKS_TENSOR_FLOAT32, of shape [batch_size, num_units].
    const HalOperand* cellStateOut = GetOutputOperand<HalPolicy>(operation, 2, model);
    if (!cellStateOut)
    {
        return Fail("%s: Could not read output 2: cellStateOut", __func__);
    }
    // 03: The output: A 2-D tensor of ANEURALNETWORKS_TENSOR_FLOAT32, of shape [batch_size, output_size]. This is
    //     effectively the same as the current “output state (out)” value.
    const HalOperand* output = GetOutputOperand<HalPolicy>(operation, 3, model);
    if (!output)
    {
        return Fail("%s: Could not read output 3: output", __func__);
    }

    // set the params structure for the AddLstmLayer call
    LstmInputParams params;
    params.m_InputToInputWeights = inputToInputWeightsPin.GetConstTensorPtr();
    params.m_InputToForgetWeights = inputToForgetWeightsPin.GetConstTensorPtr();
    params.m_InputToCellWeights = inputToCellWeightsPin.GetConstTensorPtr();
    params.m_InputToOutputWeights = inputToOutputWeightsPin.GetConstTensorPtr();
    params.m_RecurrentToInputWeights = recurrentToInputWeightsPin.GetConstTensorPtr();
    params.m_RecurrentToForgetWeights = recurrentToForgetWeightsPin.GetConstTensorPtr();
    params.m_RecurrentToCellWeights = recurrentToCellWeightsPin.GetConstTensorPtr();
    params.m_RecurrentToOutputWeights = recurrentToOutputWeightsPin.GetConstTensorPtr();
    params.m_CellToInputWeights = cellToInputWeightsPin.GetConstTensorPtr();
    params.m_CellToForgetWeights = cellToForgetWeightsPin.GetConstTensorPtr();
    params.m_CellToOutputWeights = cellToOutputWeightsPin.GetConstTensorPtr();
    params.m_InputGateBias = inputGateBiasPin.GetConstTensorPtr();
    params.m_ForgetGateBias = forgetGateBiasPin.GetConstTensorPtr();
    params.m_CellBias = cellBiasPin.GetConstTensorPtr();
    params.m_OutputGateBias = outputGateBiasPin.GetConstTensorPtr();
    params.m_ProjectionWeights = projectionWeightsPin.GetConstTensorPtr();
    params.m_ProjectionBias = projectionBiasPin.GetConstTensorPtr();
    params.m_InputLayerNormWeights = inputLayerNormWeightsPin.GetConstTensorPtr();
    params.m_ForgetLayerNormWeights = forgetLayerNormWeightsPin.GetConstTensorPtr();
    params.m_CellLayerNormWeights = cellLayerNormWeightsPin.GetConstTensorPtr();
    params.m_OutputLayerNormWeights = outputLayerNormWeightsPin.GetConstTensorPtr();

    // set the layer descriptor
    LstmDescriptor desc;
    desc.m_ActivationFunc = activation;
    desc.m_ClippingThresCell = cellClip;
    desc.m_ClippingThresProj = projClip;
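    // CIFG ("coupled input and forget gate") takes effect when the input gate parameters are omitted: the input
    // gate is then derived from the forget gate (as 1 - forget) instead of being computed from its own weights.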
    desc.m_CifgEnabled = (params.m_InputToInputWeights == nullptr ||
                          params.m_RecurrentToInputWeights == nullptr ||
                          params.m_InputGateBias == nullptr);
    desc.m_PeepholeEnabled = (params.m_CellToForgetWeights != nullptr ||
                              params.m_CellToOutputWeights != nullptr);
    desc.m_ProjectionEnabled = (params.m_ProjectionWeights != nullptr);
    desc.m_LayerNormEnabled = (params.m_InputLayerNormWeights != nullptr ||
                               params.m_ForgetLayerNormWeights != nullptr ||
                               params.m_CellLayerNormWeights != nullptr ||
                               params.m_OutputLayerNormWeights != nullptr);

    // validate the optional input groups
    if (desc.m_CifgEnabled &&
        (params.m_InputToInputWeights != nullptr ||
         params.m_RecurrentToInputWeights != nullptr ||
         params.m_InputGateBias != nullptr))
    {
        return Fail("%s: All, or none, of input-to-input weights, recurrent-to-input weights,"
                    " and input gate bias must be provided", __func__);
    }

    if (!desc.m_ProjectionEnabled && params.m_ProjectionBias != nullptr)
    {
        return Fail("%s: projection bias should not be provided without projection weights", __func__);
    }

    if (desc.m_PeepholeEnabled &&
        (params.m_CellToForgetWeights == nullptr ||
         params.m_CellToOutputWeights == nullptr ||
         (!desc.m_CifgEnabled && params.m_CellToInputWeights == nullptr)))
    {
        return Fail("%s: All, or none, of cell-to-forget weights and cell-to-output weights must be provided"
                    " and, if CIFG is not enabled, cell-to-input weights must also be provided", __func__);
    }

    if (desc.m_LayerNormEnabled &&
        (params.m_ForgetLayerNormWeights == nullptr ||
         params.m_CellLayerNormWeights == nullptr ||
         params.m_OutputLayerNormWeights == nullptr ||
         (!desc.m_CifgEnabled && params.m_InputLayerNormWeights == nullptr)))
    {
        return Fail("%s: All, or none, of forget-norm weights, cell-norm weights and output-norm weights must be"
                    " provided and, if CIFG is not enabled, input-norm weights must also be provided", __func__);
    }

    // Check if the layer is supported
    // Inputs
    const TensorInfo& inputInfo         = input.GetTensorInfo();
    const TensorInfo& outputStateInInfo = outputStateIn.GetTensorInfo();
    const TensorInfo& cellStateInInfo   = cellStateIn.GetTensorInfo();

    // Outputs
    const TensorInfo& scratchBufferInfo  = GetTensorInfoForOperand(*scratchBuffer);
    const TensorInfo& outputStateOutInfo = GetTensorInfoForOperand(*outputStateOut);
    const TensorInfo& cellStateOutInfo   = GetTensorInfoForOperand(*cellStateOut);
    const TensorInfo& outputInfo         = GetTensorInfoForOperand(*output);

2771*3e777be0SXin Li     // Basic parameters
2772*3e777be0SXin Li     LstmInputParamsInfo paramsInfo;
2773*3e777be0SXin Li     paramsInfo.m_InputToForgetWeights     = &(params.m_InputToForgetWeights->GetInfo());
2774*3e777be0SXin Li     paramsInfo.m_InputToCellWeights       = &(params.m_InputToCellWeights->GetInfo());
2775*3e777be0SXin Li     paramsInfo.m_InputToOutputWeights     = &(params.m_InputToOutputWeights->GetInfo());
2776*3e777be0SXin Li     paramsInfo.m_RecurrentToForgetWeights = &(params.m_RecurrentToForgetWeights->GetInfo());
2777*3e777be0SXin Li     paramsInfo.m_RecurrentToCellWeights   = &(params.m_RecurrentToCellWeights->GetInfo());
2778*3e777be0SXin Li     paramsInfo.m_RecurrentToOutputWeights = &(params.m_RecurrentToOutputWeights->GetInfo());
2779*3e777be0SXin Li     paramsInfo.m_ForgetGateBias           = &(params.m_ForgetGateBias->GetInfo());
2780*3e777be0SXin Li     paramsInfo.m_CellBias                 = &(params.m_CellBias->GetInfo());
2781*3e777be0SXin Li     paramsInfo.m_OutputGateBias           = &(params.m_OutputGateBias->GetInfo());
2782*3e777be0SXin Li 
2783*3e777be0SXin Li     // Optional parameters
2784*3e777be0SXin Li     if (!desc.m_CifgEnabled)
2785*3e777be0SXin Li     {
2786*3e777be0SXin Li         paramsInfo.m_InputToInputWeights = &(params.m_InputToInputWeights->GetInfo());
2787*3e777be0SXin Li         paramsInfo.m_RecurrentToInputWeights = &(params.m_RecurrentToInputWeights->GetInfo());
2788*3e777be0SXin Li         if (params.m_CellToInputWeights != nullptr)
2789*3e777be0SXin Li         {
2790*3e777be0SXin Li             paramsInfo.m_CellToInputWeights = &(params.m_CellToInputWeights->GetInfo());
2791*3e777be0SXin Li         }
2792*3e777be0SXin Li         paramsInfo.m_InputGateBias = &(params.m_InputGateBias->GetInfo());
2793*3e777be0SXin Li     }
2794*3e777be0SXin Li 
2795*3e777be0SXin Li     if (desc.m_ProjectionEnabled)
2796*3e777be0SXin Li     {
2797*3e777be0SXin Li         paramsInfo.m_ProjectionWeights = &(params.m_ProjectionWeights->GetInfo());
2798*3e777be0SXin Li         if (params.m_ProjectionBias != nullptr)
2799*3e777be0SXin Li         {
2800*3e777be0SXin Li             paramsInfo.m_ProjectionBias = &(params.m_ProjectionBias->GetInfo());
2801*3e777be0SXin Li         }
2802*3e777be0SXin Li     }
2803*3e777be0SXin Li 
2804*3e777be0SXin Li     if (desc.m_PeepholeEnabled)
2805*3e777be0SXin Li     {
2806*3e777be0SXin Li         paramsInfo.m_CellToForgetWeights = &(params.m_CellToForgetWeights->GetInfo());
2807*3e777be0SXin Li         paramsInfo.m_CellToOutputWeights = &(params.m_CellToOutputWeights->GetInfo());
2808*3e777be0SXin Li     }
2809*3e777be0SXin Li 
2810*3e777be0SXin Li     if (desc.m_LayerNormEnabled)
2811*3e777be0SXin Li     {
2812*3e777be0SXin Li         if(!desc.m_CifgEnabled)
2813*3e777be0SXin Li         {
2814*3e777be0SXin Li             paramsInfo.m_InputLayerNormWeights = &(params.m_InputLayerNormWeights->GetInfo());
2815*3e777be0SXin Li         }
2816*3e777be0SXin Li         paramsInfo.m_ForgetLayerNormWeights = &(params.m_ForgetLayerNormWeights->GetInfo());
2817*3e777be0SXin Li         paramsInfo.m_CellLayerNormWeights = &(params.m_CellLayerNormWeights->GetInfo());
2818*3e777be0SXin Li         paramsInfo.m_OutputLayerNormWeights = &(params.m_OutputLayerNormWeights->GetInfo());
2819*3e777be0SXin Li     }
2820*3e777be0SXin Li 
2821*3e777be0SXin Li     bool isSupported = false;
2822*3e777be0SXin Li     armnn::BackendId setBackend;
2823*3e777be0SXin Li     auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported)
2824*3e777be0SXin Li     {
2825*3e777be0SXin Li         FORWARD_LAYER_SUPPORT_FUNC(__func__,
2826*3e777be0SXin Li                                    IsLstmSupported,
2827*3e777be0SXin Li                                    data.m_Backends,
2828*3e777be0SXin Li                                    isSupported,
2829*3e777be0SXin Li                                    setBackend,
2830*3e777be0SXin Li                                    inputInfo,
2831*3e777be0SXin Li                                    outputStateInInfo,
2832*3e777be0SXin Li                                    cellStateInInfo,
2833*3e777be0SXin Li                                    scratchBufferInfo,
2834*3e777be0SXin Li                                    outputStateOutInfo,
2835*3e777be0SXin Li                                    cellStateOutInfo,
2836*3e777be0SXin Li                                    outputInfo,
2837*3e777be0SXin Li                                    desc,
2838*3e777be0SXin Li                                    paramsInfo);
2839*3e777be0SXin Li     };
2840*3e777be0SXin Li 
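    // NOTE: Support checking is wrapped in a lambda so that, when any of the outputs is
    // dynamic, the same validation can be re-run once the actual shapes are known.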
    bool isDynamic = false;
    if (!IsDynamicTensor(outputStateOutInfo) &&
        !IsDynamicTensor(scratchBufferInfo)  &&
        !IsDynamicTensor(cellStateOutInfo)   &&
        !IsDynamicTensor(outputInfo))
    {
        validateFunc(outputInfo, isSupported);
    }
    else
    {
        isDynamic = true;
        isSupported = AreDynamicTensorsSupported();
    }

    if (!isSupported)
    {
        return false;
    }

    // Add the layer
    IConnectableLayer* layer = data.m_Network->AddLstmLayer(desc, params, "Lstm");
    layer->SetBackendId(setBackend);

    input.Connect(layer->GetInputSlot(0));
    outputStateIn.Connect(layer->GetInputSlot(1));
    cellStateIn.Connect(layer->GetInputSlot(2));

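    // The four ArmNN LSTM output slots line up one-to-one with the operation's outputs:
    //   slot 0 -> scratch buffer, slot 1 -> output state, slot 2 -> cell state, slot 3 -> output,
    // which is why matching indices are passed to SetupAndTrackLayerOutputSlot below.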
    if (!isDynamic)
    {
        return (
             SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, 0, model, data) &&
             SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 1, *layer, 1, model, data) &&
             SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 2, *layer, 2, model, data) &&
             SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 3, *layer, 3, model, data));
    }
    else
    {
        return (
             SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, 0, model, data) &&
             SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 1, *layer, 1, model, data) &&
             SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 2, *layer, 2, model, data) &&
             SetupAndTrackLayerOutputSlot<HalPolicy>(
                 operation, 3, *layer, 3, model, data, nullptr, validateFunc, ActivationFn::kActivationNone, true));
    }
}

template<typename HalPolicy,
         typename HalOperation = typename HalPolicy::Operation,
         typename HalModel     = typename HalPolicy::Model>
bool ConvertTransposeConv2d(const HalOperation& operation, const HalModel& model, ConversionData& data)
{
    using HalOperand     = typename HalPolicy::Operand;
    using HalOperandType = typename HalPolicy::OperandType;

    LayerInputHandle input = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);

    if (!input.IsValid())
    {
        return Fail("%s: Operation has invalid inputs", __func__);
    }

    const HalOperand* output = GetOutputOperand<HalPolicy>(operation, 0, model);

    if (!output)
    {
        return Fail("%s: Could not read output 0", __func__);
    }

    const TensorInfo& inputInfo  = input.GetTensorInfo();
    const TensorInfo& outputInfo = GetTensorInfoForOperand(*output);

    // ArmNN does not currently support non-fixed weights or bias
    // Find the shape of the weights tensor. In AndroidNN this will be [depth_out, filter_height, filter_width, depth_in]
    const HalOperand* weightsOperand = GetInputOperand<HalPolicy>(operation, 1, model);

    if (weightsOperand == nullptr)
    {
        return Fail("%s: Operand is invalid", __func__);
    }
    TransposeConvolution2dDescriptor desc;
    desc.m_DataLayout = DataLayout::NHWC;

    // Determine whether padding is implicit or explicit
    bool implicitPadding = operation.inputs.size() == 9;

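    // For reference, the two operand layouts consumed below:
    //   9 inputs  (implicit padding): 0 input, 1 weights, 2 bias, 3 output shape,
    //             4 padding scheme, 5 stride W, 6 stride H, 7 activation, 8 data layout
    //   11 inputs (explicit padding): 0 input, 1 weights, 2 bias, 3-6 pad left/right/top/bottom,
    //             7 stride W, 8 stride H, 9 activation, 10 data layout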
    if (implicitPadding)
    {
        desc.m_DataLayout = OptionalDataLayout<HalPolicy>(operation, 8, model, data);
    }
    else
    {
        desc.m_DataLayout = OptionalDataLayout<HalPolicy>(operation, 10, model, data);
    }

    armnnUtils::DataLayoutIndexed dataLayoutIndexed(desc.m_DataLayout);
    unsigned int widthIndex = dataLayoutIndexed.GetWidthIndex();
    unsigned int heightIndex = dataLayoutIndexed.GetHeightIndex();

    const PermutationVector OHWIToOIHW = {0, 2, 3, 1};

    // The shape of the weight is [depth_out, filter_height, filter_width, depth_in].
    // We have to permute it to OIHW if the data layout is NCHW.
    const ConstTensorPin weightsPin = (desc.m_DataLayout == DataLayout::NCHW) ?
                                      ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 1,
                                                                                       model, data, OHWIToOIHW) :
                                      ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 1, model, data);
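    // e.g. (hypothetical shape) OHWI weights of [16, 3, 3, 8] come out as OIHW [16, 8, 3, 3]:
    // each entry of {0, 2, 3, 1} gives the destination position of the corresponding source dimension.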

    // Bias is a 1D tensor
    const ConstTensorPin biasPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 2, model, data);

    if (!weightsPin.IsValid())
    {
        return Fail("%s: Operation has invalid weights", __func__);
    }

    if (!biasPin.IsValid())
    {
        return Fail("%s: Operation has invalid biases", __func__);
    }

    ConstTensor weights = weightsPin.GetConstTensor();
    ConstTensor bias = biasPin.GetConstTensor();
    SanitizeBiasQuantizationScale(bias.GetInfo(), weights.GetInfo(), inputInfo);
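    // NOTE: For quantized tensors the bias scale is conventionally inputScale * weightsScale;
    // SanitizeBiasQuantizationScale nudges the bias info towards that product when the
    // model-supplied value has drifted.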

    ActivationFn activation;

    if (implicitPadding)
    {
        int32_t strideX{0};
        int32_t strideY{0};
        int32_t padLeft{0};
        int32_t padRight{0};
        int32_t padTop{0};
        int32_t padBottom{0};

        android::nn::PaddingScheme paddingScheme;
        if (!GetInputPaddingScheme<HalPolicy>(operation, 4, paddingScheme, model, data) ||
            !GetInputScalar<HalPolicy>(operation, 5, HalOperandType::INT32, strideX, model, data) ||
            !GetInputScalar<HalPolicy>(operation, 6, HalOperandType::INT32, strideY, model, data) ||
            !GetInputActivationFunction<HalPolicy>(operation, 7, activation, model, data))
        {
            return Fail("%s: Operation has invalid inputs (implicit padding)", __func__);
        }

        const uint32_t kernelX = weights.GetShape()[widthIndex];
        const uint32_t kernelY = weights.GetShape()[heightIndex];

        // If output shape has been specified as a parameter then extract it and make it available.
        const HalOperand* outputShapeOperand = GetInputOperand<HalPolicy>(operation, 3, model, false);
        std::vector<int32_t> outputShape;
        if ((outputShapeOperand) && (GetTensorInt32Values<HalPolicy>(*outputShapeOperand, outputShape, model, data)))
        {
            // Change from signed to unsigned int to store in TransposeConvolution2dDescriptor.
            for (int dimension : outputShape)
            {
                desc.m_OutputShape.push_back(static_cast<unsigned int>(dimension));
            }
            desc.m_OutputShapeEnabled = true;
        }

        uint32_t outputX;
        uint32_t outputY;

        if (IsDynamicTensor(outputInfo))
        {
            if (outputShape.size() == 0)
            {
                return Fail("%s: Padding sizes cannot be inferred", __func__);
            }

            outputX = outputShape[widthIndex];
            outputY = outputShape[heightIndex];
        }
        else
        {
            outputX = outputInfo.GetShape()[widthIndex];
            outputY = outputInfo.GetShape()[heightIndex];
        }

        CalcPaddingTransposeConv(outputX, kernelX, strideX, padLeft, padRight, paddingScheme);
        CalcPaddingTransposeConv(outputY, kernelY, strideY, padTop, padBottom, paddingScheme);
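        // For reference, transpose convolution follows (ignoring any output padding):
        //     outputX = (inputX - 1) * strideX + kernelX - padLeft - padRight
        // so the helper derives the pad values from the requested output extent, kernel size,
        // stride and padding scheme; negative results are rejected just below.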

        // NOTE: The Android NN API allows for negative padding values in TransposeConv2d,
        // but Arm NN only supports values >= 0
        if (padLeft < 0 || padRight < 0 || padTop < 0 || padBottom < 0)
        {
            return Fail("%s: Negative padding values are not supported", __func__);
        }

        desc.m_StrideX   = armnn::numeric_cast<uint32_t>(strideX);
        desc.m_StrideY   = armnn::numeric_cast<uint32_t>(strideY);
        desc.m_PadLeft   = armnn::numeric_cast<uint32_t>(padLeft);
        desc.m_PadRight  = armnn::numeric_cast<uint32_t>(padRight);
        desc.m_PadTop    = armnn::numeric_cast<uint32_t>(padTop);
        desc.m_PadBottom = armnn::numeric_cast<uint32_t>(padBottom);
    }
    else if (operation.inputs.size() == 11)
    {
        // Explicit padding
        if (!GetInputScalar<HalPolicy>(operation, 3, HalOperandType::INT32, desc.m_PadLeft, model, data) ||
            !GetInputScalar<HalPolicy>(operation, 4, HalOperandType::INT32, desc.m_PadRight, model, data) ||
            !GetInputScalar<HalPolicy>(operation, 5, HalOperandType::INT32, desc.m_PadTop, model, data) ||
            !GetInputScalar<HalPolicy>(operation, 6, HalOperandType::INT32, desc.m_PadBottom, model, data) ||
            !GetInputScalar<HalPolicy>(operation, 7, HalOperandType::INT32, desc.m_StrideX, model, data) ||
            !GetInputScalar<HalPolicy>(operation, 8, HalOperandType::INT32, desc.m_StrideY, model, data) ||
            !GetInputActivationFunction<HalPolicy>(operation, 9, activation, model, data))
        {
            return Fail("%s: Operation has invalid inputs (explicit padding)", __func__);
        }
    }
    else
    {
        return Fail("%s: Unsupported number of operation inputs", __func__);
    }

    desc.m_BiasEnabled = true;
    Optional<TensorInfo> biases(bias.GetInfo());

    bool isSupported = false;
    armnn::BackendId setBackend;
    auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported)
    {
        FORWARD_LAYER_SUPPORT_FUNC(__func__,
                                   IsTransposeConvolution2dSupported,
                                   data.m_Backends,
                                   isSupported,
                                   setBackend,
                                   inputInfo,
                                   outputInfo,
                                   desc,
                                   weights.GetInfo(),
                                   biases);
    };

    if (IsDynamicTensor(outputInfo))
    {
        isSupported = AreDynamicTensorsSupported();
    }
    else
    {
        validateFunc(outputInfo, isSupported);
    }
    if (!isSupported)
    {
        return false;
    }

    IConnectableLayer* startLayer =
        data.m_Network->AddTransposeConvolution2dLayer(desc, weights, Optional<ConstTensor>(bias));
    if (!startLayer)
    {
        return Fail("%s: AddTransposeConvolution2dLayer failed", __func__);
    }
    startLayer->SetBackendId(setBackend);

    input.Connect(startLayer->GetInputSlot(0));

    return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *startLayer, model,
                                                   data, nullptr, validateFunc, activation);
}

template<typename HalPolicy,
         typename HalOperation = typename HalPolicy::Operation,
         typename HalModel     = typename HalPolicy::Model>
bool ConvertUnidirectionalSequenceLstm(const HalOperation& operation,
                                       const HalModel& model,
                                       ConversionData& data)
{
    using HalOperand = typename HalPolicy::Operand;
    using HalOperandType = typename HalPolicy::OperandType;

    ALOGV("HalPolicy::ConvertUnidirectionalSequenceLstm()");

    // Determine if input OperandType is ANEURALNETWORKS_TENSOR_FLOAT32 or ANEURALNETWORKS_TENSOR_FLOAT16
    HalOperandType inputType;
    if (!GetOperandType<HalPolicy>(operation, 0, model, inputType))
    {
        return Fail("%s: Operation has invalid inputs", __func__);
    }

    // Inputs:
    // 0: The input: A 3-D tensor of shape: If time-major: [max_time, batch_size, input_size] If batch-major:
    // [batch_size, max_time, input_size] where “max_time” is the number of timesteps (sequence length), “batch_size”
    // corresponds to the batching dimension, and “input_size” is the size of the input.
    LayerInputHandle input = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);
    if (!input.IsValid())
    {
        return Fail("%s: Could not read input 0: input", __func__);
    }
    // 18: The output state: A 2-D tensor of ANEURALNETWORKS_TENSOR_FLOAT32/16, of shape [batch_size, output_size].
    LayerInputHandle outputStateIn = ConvertToLayerInputHandle<HalPolicy>(operation, 18, model, data);
    if (!outputStateIn.IsValid())
    {
        return Fail("%s: Could not read input 18: outputStateIn", __func__);
    }
    // 19: The cell state: A 2-D tensor of ANEURALNETWORKS_TENSOR_FLOAT32/16, of shape [batch_size, num_units].
    LayerInputHandle cellStateIn = ConvertToLayerInputHandle<HalPolicy>(operation, 19, model, data);
    if (!cellStateIn.IsValid())
    {
        return Fail("%s: Could not read input 19: cellStateIn", __func__);
    }

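    // NOTE: The weight tensors below go through DequantizeAndMakeConstTensorPin so that
    // quantized weights fed to this operation via a DEQUANTIZE operation can be folded
    // into constant tensors on the fly rather than rejected as non-constant inputs.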
    // Get the mandatory input tensors:
    // 02: The input-to-forget weights: A 2-D tensor of ANEURALNETWORKS_TENSOR_FLOAT32/16, of shape
    //     [num_units, input_size].
    const ConstTensorPin inputToForgetWeightsPin =
                             (DequantizeAndMakeConstTensorPin<HalPolicy>(operation, model, data, 2));
    // 03: The input-to-cell weights: A 2-D tensor of ANEURALNETWORKS_TENSOR_FLOAT32/16, of shape
    //     [num_units, input_size].
    const ConstTensorPin inputToCellWeightsPin =
                             (DequantizeAndMakeConstTensorPin<HalPolicy>(operation, model, data, 3));
    // 04: The input-to-output weights: A 2-D tensor of ANEURALNETWORKS_TENSOR_FLOAT32/16, of shape
    //     [num_units, input_size].
    const ConstTensorPin inputToOutputWeightsPin =
                             (DequantizeAndMakeConstTensorPin<HalPolicy>(operation, model, data, 4));
    // 06: The recurrent-to-forget weights: A 2-D tensor of ANEURALNETWORKS_TENSOR_FLOAT32/16, of shape
    //     [num_units, output_size].
    const ConstTensorPin recurrentToForgetWeightsPin =
                             (DequantizeAndMakeConstTensorPin<HalPolicy>(operation, model, data, 6));
    // 07: The recurrent-to-cell weights: A 2-D tensor of ANEURALNETWORKS_TENSOR_FLOAT32/16, of shape
    //     [num_units, output_size].
    const ConstTensorPin recurrentToCellWeightsPin =
                             (DequantizeAndMakeConstTensorPin<HalPolicy>(operation, model, data, 7));
    // 08: The recurrent-to-output weights: A 2-D tensor of ANEURALNETWORKS_TENSOR_FLOAT32/16, of shape
    //     [num_units, output_size].
    const ConstTensorPin recurrentToOutputWeightsPin =
                             (DequantizeAndMakeConstTensorPin<HalPolicy>(operation, model, data, 8));
    // 13: The forget gate bias: A 1-D tensor of ANEURALNETWORKS_TENSOR_FLOAT32/16, of shape [num_units].
    const ConstTensorPin forgetGateBiasPin =
                             ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 13, model, data);
    // 14: The cell bias: A 1-D tensor of ANEURALNETWORKS_TENSOR_FLOAT32/16, of shape [num_units].
    const ConstTensorPin cellBiasPin =
                             ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 14, model, data);
    // 15: The output gate bias: A 1-D tensor of ANEURALNETWORKS_TENSOR_FLOAT32/16, of shape [num_units].
    const ConstTensorPin outputGateBiasPin =
                             ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 15, model, data);

    if (!inputToForgetWeightsPin.IsValid() ||
        !inputToCellWeightsPin.IsValid() ||
        !inputToOutputWeightsPin.IsValid() ||
        !recurrentToForgetWeightsPin.IsValid() ||
        !recurrentToCellWeightsPin.IsValid() ||
        !recurrentToOutputWeightsPin.IsValid() ||
        !forgetGateBiasPin.IsValid() ||
        !cellBiasPin.IsValid() ||
        !outputGateBiasPin.IsValid())
    {
        return Fail("%s: Operation has invalid tensor inputs", __func__);
    }

    // Get the optional input tensors:
    // 01: The input-to-input weights: Optional. A 2-D tensor of ANEURALNETWORKS_TENSOR_FLOAT32/16, of shape
    //     [num_units, input_size], where “num_units” corresponds to the number of cell units.
    const ConstTensorPin inputToInputWeightsPin =
                             (DequantizeAndMakeConstTensorPin<HalPolicy>(operation, model, data, 1, true));
    // 05: The recurrent-to-input weights: Optional. A 2-D tensor of ANEURALNETWORKS_TENSOR_FLOAT32/16, of shape
    //     [num_units, output_size], where “output_size” corresponds to either the number of cell units (i.e.,
    //     “num_units”), or the second dimension of the “projection_weights”, if defined.
    const ConstTensorPin recurrentToInputWeightsPin =
                             (DequantizeAndMakeConstTensorPin<HalPolicy>(operation, model, data, 5, true));
    // 09: The cell-to-input weights: Optional.
    // A 1-D tensor of ANEURALNETWORKS_TENSOR_FLOAT32/16, of shape [num_units].
    const ConstTensorPin cellToInputWeightsPin =
                             (DequantizeAndMakeConstTensorPin<HalPolicy>(operation, model, data, 9, true));
    // 10: The cell-to-forget weights: Optional.
    // A 1-D tensor of ANEURALNETWORKS_TENSOR_FLOAT32/16, of shape [num_units].
    const ConstTensorPin cellToForgetWeightsPin =
                             (DequantizeAndMakeConstTensorPin<HalPolicy>(operation, model, data, 10, true));
    // 11: The cell-to-output weights: Optional.
    // A 1-D tensor of ANEURALNETWORKS_TENSOR_FLOAT32/16, of shape [num_units].
    const ConstTensorPin cellToOutputWeightsPin =
                             (DequantizeAndMakeConstTensorPin<HalPolicy>(operation, model, data, 11, true));
    // 12: The input gate bias: Optional. A 1-D tensor of ANEURALNETWORKS_TENSOR_FLOAT32/16, of shape [num_units].
    const ConstTensorPin inputGateBiasPin =
                             ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                                              12,
                                                                              model,
                                                                              data,
                                                                              g_DontPermute,
                                                                              nullptr,
                                                                              true);

    // 16: The projection weights: Optional. A 2-D tensor of ANEURALNETWORKS_TENSOR_FLOAT32/16, of shape
    //     [output_size, num_units].
    const ConstTensorPin projectionWeightsPin =
                             (DequantizeAndMakeConstTensorPin<HalPolicy>(operation, model, data, 16, true));
    // 17: The projection bias: Optional. A 1-D tensor of ANEURALNETWORKS_TENSOR_FLOAT32/16, of shape [output_size].
    const ConstTensorPin projectionBiasPin =
                             ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                                              17,
                                                                              model,
                                                                              data,
                                                                              g_DontPermute,
                                                                              nullptr,
                                                                              true);

    if ((!inputToInputWeightsPin.IsValid() && !inputToInputWeightsPin.IsOptional()) ||
        (!recurrentToInputWeightsPin.IsValid() && !recurrentToInputWeightsPin.IsOptional()) ||
        (!cellToInputWeightsPin.IsValid() && !cellToInputWeightsPin.IsOptional()) ||
        (!cellToForgetWeightsPin.IsValid() && !cellToForgetWeightsPin.IsOptional()) ||
        (!cellToOutputWeightsPin.IsValid() && !cellToOutputWeightsPin.IsOptional()) ||
        (!inputGateBiasPin.IsValid() && !inputGateBiasPin.IsOptional()) ||
        (!projectionWeightsPin.IsValid() && !projectionWeightsPin.IsOptional()) ||
        (!projectionBiasPin.IsValid() && !projectionBiasPin.IsOptional()))
    {
        return Fail("%s: Operation has invalid tensor inputs", __func__);
    }

    // Get the mandatory input scalars (actually 1-D tensors of size 1):
    // 20: The activation function: A value indicating the activation function:
    //     0: None; 1: Relu; 3: Relu6; 4: Tanh; 6: Sigmoid.
    // 21: The clipping threshold for the cell state, such that values are bound within [-cell_clip, cell_clip].
    //     If set to 0.0 then clipping is disabled.
    // 22: The clipping threshold for the output from the projection layer, such that values are bound within
    //     [-proj_clip, proj_clip]. If set to 0.0 then clipping is disabled.
    // The clipping parameters are read as FLOAT32 or FLOAT16 depending on the input tensor type.
    ActivationFn activation = ActivationFn::kActivationNone;
    LstmDescriptor desc;

    if (inputType == HalOperandType::TENSOR_FLOAT32)
    {
        float cellClip;
        float projClip;

        if (!GetInputActivationFunctionFromTensor<HalPolicy>(operation, 20, activation, model, data) ||
            !GetInputScalar<HalPolicy>(operation, 21, HalOperandType::FLOAT32, cellClip, model, data) ||
            !GetInputScalar<HalPolicy>(operation, 22, HalOperandType::FLOAT32, projClip, model, data))
        {
            return Fail("%s: Operation has invalid scalar inputs", __func__);
        }

        desc.m_ClippingThresCell = cellClip;
        desc.m_ClippingThresProj = projClip;
    }

    if (inputType == HalOperandType::TENSOR_FLOAT16)
    {
        Half cellClip;
        Half projClip;

        if (!GetInputActivationFunctionFromTensor<HalPolicy>(operation, 20, activation, model, data) ||
            !GetInputScalar<HalPolicy>(operation, 21, HalOperandType::FLOAT16, cellClip, model, data) ||
            !GetInputScalar<HalPolicy>(operation, 22, HalOperandType::FLOAT16, projClip, model, data))
        {
            return Fail("%s: Operation has invalid scalar inputs", __func__);
        }

        desc.m_ClippingThresCell = cellClip;
        desc.m_ClippingThresProj = projClip;
    }

    // Determine if time-major or batch-major.
    // 23: Time-major if true, batch-major if false.
    bool isTimeMajor = GetOptionalBool<HalPolicy>(operation, 23, model, data);
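    // e.g. with hypothetical values max_time = 5, batch_size = 2, input_size = 4:
    //   time-major  input shape: [5, 2, 4]
    //   batch-major input shape: [2, 5, 4]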

    // Get the normalization tensors
    // 24: The input layer normalization weights. A 1-D tensor of shape [num_units].
    //     Used to rescale normalized inputs to activation at input gate.
    const ConstTensorPin inputLayerNormWeightsPin =
                             (DequantizeAndMakeConstTensorPin<HalPolicy>(operation, model, data, 24, true));

    // 25: The forget layer normalization weights. A 1-D tensor of shape [num_units].
    //     Used to rescale normalized inputs to activation at forget gate.
    const ConstTensorPin forgetLayerNormWeightsPin =
                             ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                                              25,
                                                                              model,
                                                                              data,
                                                                              g_DontPermute,
                                                                              nullptr,
                                                                              true);

    // 26: The cell layer normalization weights. A 1-D tensor of shape [num_units].
    //     Used to rescale normalized inputs to activation at cell gate.
    const ConstTensorPin cellLayerNormWeightsPin =
                             ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                                              26,
                                                                              model,
                                                                              data,
                                                                              g_DontPermute,
                                                                              nullptr,
                                                                              true);

    // 27: The output layer normalization weights. A 1-D tensor of shape [num_units].
    //     Used to rescale normalized inputs to activation at output gate.
    const ConstTensorPin outputLayerNormWeightsPin =
                             ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                                              27,
                                                                              model,
                                                                              data,
                                                                              g_DontPermute,
                                                                              nullptr,
                                                                              true);

    // Outputs:
    // 00: The output: A 3-D tensor of ANEURALNETWORKS_TENSOR_FLOAT32/16. Shape: if time-major:
    // [max_time, batch_size, output_size] If batch-major: [batch_size, max_time, output_size]
    const HalOperand* output = GetOutputOperand<HalPolicy>(operation, 0, model);
    if (!output)
    {
        return Fail("%s: Could not read output 0: output", __func__);
    }

    // 01 & 02: hiddenStateOut and cellStateOut are not currently supported by our Android versioning.

    // Set the params structure for the AddUnidirectionalSequenceLstmLayer call
    LstmInputParams params;
    params.m_InputToInputWeights = inputToInputWeightsPin.GetConstTensorPtr();
    params.m_InputToForgetWeights = inputToForgetWeightsPin.GetConstTensorPtr();
    params.m_InputToCellWeights = inputToCellWeightsPin.GetConstTensorPtr();
    params.m_InputToOutputWeights = inputToOutputWeightsPin.GetConstTensorPtr();
    params.m_RecurrentToInputWeights = recurrentToInputWeightsPin.GetConstTensorPtr();
    params.m_RecurrentToForgetWeights = recurrentToForgetWeightsPin.GetConstTensorPtr();
    params.m_RecurrentToCellWeights = recurrentToCellWeightsPin.GetConstTensorPtr();
    params.m_RecurrentToOutputWeights = recurrentToOutputWeightsPin.GetConstTensorPtr();
    params.m_CellToInputWeights = cellToInputWeightsPin.GetConstTensorPtr();
    params.m_CellToForgetWeights = cellToForgetWeightsPin.GetConstTensorPtr();
    params.m_CellToOutputWeights = cellToOutputWeightsPin.GetConstTensorPtr();
    params.m_InputGateBias = inputGateBiasPin.GetConstTensorPtr();
    params.m_ForgetGateBias = forgetGateBiasPin.GetConstTensorPtr();
    params.m_CellBias = cellBiasPin.GetConstTensorPtr();
    params.m_OutputGateBias = outputGateBiasPin.GetConstTensorPtr();
    params.m_ProjectionWeights = projectionWeightsPin.GetConstTensorPtr();
    params.m_ProjectionBias = projectionBiasPin.GetConstTensorPtr();
    params.m_InputLayerNormWeights = inputLayerNormWeightsPin.GetConstTensorPtr();
    params.m_ForgetLayerNormWeights = forgetLayerNormWeightsPin.GetConstTensorPtr();
    params.m_CellLayerNormWeights = cellLayerNormWeightsPin.GetConstTensorPtr();
    params.m_OutputLayerNormWeights = outputLayerNormWeightsPin.GetConstTensorPtr();

    // Set the layer descriptor
    desc.m_ActivationFunc = activation;
    desc.m_CifgEnabled = (params.m_InputToInputWeights == nullptr ||
        params.m_RecurrentToInputWeights == nullptr ||
        params.m_InputGateBias == nullptr);
    desc.m_PeepholeEnabled = (params.m_CellToForgetWeights != nullptr ||
        params.m_CellToOutputWeights != nullptr);
    desc.m_ProjectionEnabled = (params.m_ProjectionWeights != nullptr);
    desc.m_LayerNormEnabled = (params.m_InputLayerNormWeights != nullptr ||
        params.m_ForgetLayerNormWeights != nullptr ||
        params.m_CellLayerNormWeights != nullptr ||
        params.m_OutputLayerNormWeights != nullptr);
    desc.m_TimeMajor = isTimeMajor;

    // Validate the optional input groups
    if (desc.m_CifgEnabled &&
        (params.m_InputToInputWeights != nullptr ||
            params.m_RecurrentToInputWeights != nullptr ||
            params.m_InputGateBias != nullptr))
    {
        return Fail("%s: All, or none, of input-to-input weights, recurrent-to-input weights,"
                    " and input gate bias must be provided", __func__);
    }

    if (!desc.m_ProjectionEnabled && params.m_ProjectionBias != nullptr)
    {
        return Fail("%s: projection bias should not be provided without projection weights", __func__);
    }

    if (desc.m_PeepholeEnabled &&
        (params.m_CellToForgetWeights == nullptr ||
            params.m_CellToOutputWeights == nullptr ||
            (!desc.m_CifgEnabled && params.m_CellToInputWeights == nullptr)))
    {
        return Fail("%s: All, or none, of cell-to-forget weights and cell-to-output weights must be provided"
                    " and, if CIFG is not enabled, cell-to-input weights must also be provided", __func__);
    }

    if (desc.m_LayerNormEnabled &&
        (params.m_ForgetLayerNormWeights == nullptr ||
            params.m_CellLayerNormWeights == nullptr ||
            params.m_OutputLayerNormWeights == nullptr ||
            (!desc.m_CifgEnabled && params.m_InputLayerNormWeights == nullptr)))
    {
        return Fail("%s: All, or none, of forget-norm weights, cell-norm weights and output-norm weights must be"
                    " provided and, if CIFG is not enabled, input-norm weights must also be provided", __func__);
    }

    // Check if the layer is supported
    // Inputs
    const TensorInfo& inputInfo         = input.GetTensorInfo();
    const TensorInfo& outputStateInInfo = outputStateIn.GetTensorInfo();
    const TensorInfo& cellStateInInfo   = cellStateIn.GetTensorInfo();

    // Outputs
    const TensorInfo& outputInfo         = GetTensorInfoForOperand(*output);

    unsigned int batchSize               = inputInfo.GetShape()[0];
    unsigned int outputSize              = outputInfo.GetShape()[2];
    unsigned int numUnits                = cellStateInInfo.GetShape()[1];

    armnn::DataType dataType             = inputInfo.GetDataType();
    float qScale                         = inputInfo.GetQuantizationScale();
    int qOffset                          = inputInfo.GetQuantizationOffset();

    armnn::TensorInfo cellStateOutInfo({batchSize, numUnits}, cellStateInInfo.GetDataType(),
                                       cellStateInInfo.GetQuantizationScale(), cellStateInInfo.GetQuantizationOffset());
    armnn::TensorInfo outputStateOutInfo({batchSize, outputSize}, dataType, qScale, qOffset);
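    // NOTE: Since outputs 01 and 02 are not exposed here, the state-out tensor infos above
    // are synthesised from the input and state-in shapes purely so the backend support
    // query below can see the layer's full signature.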

    // Basic parameters
    LstmInputParamsInfo paramsInfo;
    paramsInfo.m_InputToForgetWeights     = &(params.m_InputToForgetWeights->GetInfo());
    paramsInfo.m_InputToCellWeights       = &(params.m_InputToCellWeights->GetInfo());
    paramsInfo.m_InputToOutputWeights     = &(params.m_InputToOutputWeights->GetInfo());
    paramsInfo.m_RecurrentToForgetWeights = &(params.m_RecurrentToForgetWeights->GetInfo());
    paramsInfo.m_RecurrentToCellWeights   = &(params.m_RecurrentToCellWeights->GetInfo());
    paramsInfo.m_RecurrentToOutputWeights = &(params.m_RecurrentToOutputWeights->GetInfo());
    paramsInfo.m_ForgetGateBias           = &(params.m_ForgetGateBias->GetInfo());
    paramsInfo.m_CellBias                 = &(params.m_CellBias->GetInfo());
    paramsInfo.m_OutputGateBias           = &(params.m_OutputGateBias->GetInfo());

    // Optional parameters
    if (!desc.m_CifgEnabled)
    {
        paramsInfo.m_InputToInputWeights = &(params.m_InputToInputWeights->GetInfo());
        paramsInfo.m_RecurrentToInputWeights = &(params.m_RecurrentToInputWeights->GetInfo());
        if (params.m_CellToInputWeights != nullptr)
        {
            paramsInfo.m_CellToInputWeights = &(params.m_CellToInputWeights->GetInfo());
        }
        paramsInfo.m_InputGateBias = &(params.m_InputGateBias->GetInfo());
    }

    if (desc.m_ProjectionEnabled)
    {
        paramsInfo.m_ProjectionWeights = &(params.m_ProjectionWeights->GetInfo());
        if (params.m_ProjectionBias != nullptr)
        {
            paramsInfo.m_ProjectionBias = &(params.m_ProjectionBias->GetInfo());
        }
    }

    if (desc.m_PeepholeEnabled)
    {
        paramsInfo.m_CellToForgetWeights = &(params.m_CellToForgetWeights->GetInfo());
        paramsInfo.m_CellToOutputWeights = &(params.m_CellToOutputWeights->GetInfo());
    }

    if (desc.m_LayerNormEnabled)
    {
        if (!desc.m_CifgEnabled)
        {
            paramsInfo.m_InputLayerNormWeights = &(params.m_InputLayerNormWeights->GetInfo());
        }
        paramsInfo.m_ForgetLayerNormWeights = &(params.m_ForgetLayerNormWeights->GetInfo());
        paramsInfo.m_CellLayerNormWeights = &(params.m_CellLayerNormWeights->GetInfo());
        paramsInfo.m_OutputLayerNormWeights = &(params.m_OutputLayerNormWeights->GetInfo());
    }

    bool isSupported = false;
    armnn::BackendId setBackend;
    auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported)
    {
        FORWARD_LAYER_SUPPORT_FUNC(__func__,
                                   IsUnidirectionalSequenceLstmSupported,
                                   data.m_Backends,
                                   isSupported,
                                   setBackend,
                                   inputInfo,
                                   outputStateInInfo,
                                   cellStateInInfo,
                                   outputStateOutInfo,
                                   cellStateOutInfo,
                                   outputInfo,
                                   desc,
                                   paramsInfo);
    };

    bool isDynamic = false;
    if (!IsDynamicTensor(outputInfo))
    {
        validateFunc(outputInfo, isSupported);
    }
    else
    {
        isDynamic = true;
        isSupported = AreDynamicTensorsSupported();
    }

    if (!isSupported)
    {
        return false;
    }

    // Add the layer
    IConnectableLayer* layer = data.m_Network->AddUnidirectionalSequenceLstmLayer(desc,
                                                                                  params,
                                                                                  "UnidirectionalSequenceLstm");
    layer->SetBackendId(setBackend);

    input.Connect(layer->GetInputSlot(0));
    outputStateIn.Connect(layer->GetInputSlot(1));
    cellStateIn.Connect(layer->GetInputSlot(2));

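    // The ArmNN layer's slot 2 carries the sequence output (slots 0 and 1 hold the
    // unexposed state outputs), hence operation output 0 maps onto layer slot 2 below.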
    if (!isDynamic)
    {
        return (SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, 2, model, data));
    }
    else
    {
        return (SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, 2, model, data, nullptr,
                                                        validateFunc, ActivationFn::kActivationNone, true));
    }
}

} // armnn_driver namespace