xref: /aosp_15_r20/external/android-nn-driver/ConversionUtils_1_3.hpp (revision 3e777be0405cee09af5d5785ff37f7cfb5bee59a)
1*3e777be0SXin Li //
2*3e777be0SXin Li // Copyright © 2020,2022 Arm Ltd and Contributors. All rights reserved.
3*3e777be0SXin Li // SPDX-License-Identifier: MIT
4*3e777be0SXin Li //
5*3e777be0SXin Li 
6*3e777be0SXin Li #pragma once
7*3e777be0SXin Li 
8*3e777be0SXin Li #include "ConversionUtils_1_2.hpp"
9*3e777be0SXin Li 
10*3e777be0SXin Li using Half = half_float::half;
11*3e777be0SXin Li 
12*3e777be0SXin Li namespace armnn_driver
13*3e777be0SXin Li {
14*3e777be0SXin Li 
15*3e777be0SXin Li using namespace armnn;
16*3e777be0SXin Li using namespace android::nn;
17*3e777be0SXin Li 
18*3e777be0SXin Li template<typename HalPolicy,
19*3e777be0SXin Li          typename HalOperation = typename HalPolicy::Operation,
20*3e777be0SXin Li          typename HalModel     = typename HalPolicy::Model>
ConvertElu(const HalOperation & operation,const HalModel & model,ConversionData & data)21*3e777be0SXin Li bool ConvertElu(const HalOperation& operation, const HalModel& model, ConversionData& data)
22*3e777be0SXin Li {
23*3e777be0SXin Li     using HalOperandType = typename HalPolicy::OperandType;
24*3e777be0SXin Li 
25*3e777be0SXin Li     LayerInputHandle input0 = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);
26*3e777be0SXin Li     if (!input0.IsValid())
27*3e777be0SXin Li     {
28*3e777be0SXin Li         return Fail("%s: Operation has invalid inputs", __func__);
29*3e777be0SXin Li     }
30*3e777be0SXin Li 
31*3e777be0SXin Li     // Determine data type of input tensor
32*3e777be0SXin Li     HalOperandType inputType;
33*3e777be0SXin Li     if (!GetOperandType<HalPolicy>(operation, 0, model, inputType))
34*3e777be0SXin Li     {
35*3e777be0SXin Li         return Fail("%s: Operation has invalid inputs", __func__);
36*3e777be0SXin Li     }
37*3e777be0SXin Li 
38*3e777be0SXin Li     ActivationDescriptor desc;
39*3e777be0SXin Li     desc.m_Function = ActivationFunction::Elu;
40*3e777be0SXin Li 
41*3e777be0SXin Li     // Read alpha
42*3e777be0SXin Li     if (inputType == HalOperandType::TENSOR_FLOAT16)
43*3e777be0SXin Li     {
44*3e777be0SXin Li         Half alpha;
45*3e777be0SXin Li 
46*3e777be0SXin Li         if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::FLOAT16, alpha, model, data))
47*3e777be0SXin Li         {
48*3e777be0SXin Li             return Fail("%s: Operation has invalid inputs (FLOAT16)", __func__);
49*3e777be0SXin Li         }
50*3e777be0SXin Li 
51*3e777be0SXin Li         desc.m_A = static_cast<float>(alpha);
52*3e777be0SXin Li     }
53*3e777be0SXin Li     else if (inputType == HalOperandType::TENSOR_FLOAT32)
54*3e777be0SXin Li     {
55*3e777be0SXin Li         if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::FLOAT32, desc.m_A, model, data))
56*3e777be0SXin Li         {
57*3e777be0SXin Li             return Fail("%s: Operation has invalid inputs (FLOAT32)", __func__);
58*3e777be0SXin Li         }
59*3e777be0SXin Li     }
60*3e777be0SXin Li     else
61*3e777be0SXin Li     {
62*3e777be0SXin Li         return Fail("%s: Unsupported input tensor type: %d", __func__, inputType);
63*3e777be0SXin Li     }
64*3e777be0SXin Li 
65*3e777be0SXin Li     return ::ConvertToActivation<HalPolicy>(operation, __func__, desc, model, data);
66*3e777be0SXin Li }
67*3e777be0SXin Li 
68*3e777be0SXin Li template<typename HalPolicy,
69*3e777be0SXin Li     typename HalOperation = typename HalPolicy::Operation,
70*3e777be0SXin Li     typename HalModel     = typename HalPolicy::Model>
ConvertFill(const HalOperation & operation,const HalModel & model,ConversionData & data)71*3e777be0SXin Li bool ConvertFill(const HalOperation& operation, const HalModel& model, ConversionData& data)
72*3e777be0SXin Li {
73*3e777be0SXin Li     using HalOperand     = typename HalPolicy::Operand;
74*3e777be0SXin Li     using HalOperandType = typename HalPolicy::OperandType;
75*3e777be0SXin Li 
76*3e777be0SXin Li     LayerInputHandle input = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);
77*3e777be0SXin Li     if (!input.IsValid())
78*3e777be0SXin Li     {
79*3e777be0SXin Li         return Fail("%s: Operation has invalid inputs", __func__);
80*3e777be0SXin Li     }
81*3e777be0SXin Li 
82*3e777be0SXin Li     const HalOperand* output = GetOutputOperand<HalPolicy>(operation, 0, model);
83*3e777be0SXin Li     if (!output)
84*3e777be0SXin Li     {
85*3e777be0SXin Li         return Fail("%s: Could not read output", __func__);
86*3e777be0SXin Li     }
87*3e777be0SXin Li 
88*3e777be0SXin Li     const TensorInfo& inputInfo  = input.GetTensorInfo();
89*3e777be0SXin Li     const TensorInfo& outputInfo = GetTensorInfoForOperand(*output);
90*3e777be0SXin Li     if (IsDynamicTensor(outputInfo))
91*3e777be0SXin Li     {
92*3e777be0SXin Li         return Fail("%s: Dynamic output tensors are not supported", __func__);
93*3e777be0SXin Li     }
94*3e777be0SXin Li 
95*3e777be0SXin Li     // Determine data type of output tensor
96*3e777be0SXin Li     HalOperandType outputType = output->type;
97*3e777be0SXin Li     FillDescriptor descriptor;
98*3e777be0SXin Li     // Read the scalar fill value
99*3e777be0SXin Li     if (outputType == HalOperandType::TENSOR_FLOAT16)
100*3e777be0SXin Li     {
101*3e777be0SXin Li         Half value;
102*3e777be0SXin Li 
103*3e777be0SXin Li         if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::FLOAT16, value, model, data))
104*3e777be0SXin Li         {
105*3e777be0SXin Li             return Fail("%s: Operation has invalid inputs %d", __func__, outputType);
106*3e777be0SXin Li         }
107*3e777be0SXin Li 
108*3e777be0SXin Li         descriptor.m_Value = static_cast<float>(value);
109*3e777be0SXin Li     }
110*3e777be0SXin Li     else if (outputType == HalOperandType::TENSOR_FLOAT32)
111*3e777be0SXin Li     {
112*3e777be0SXin Li         if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::FLOAT32, descriptor.m_Value, model, data))
113*3e777be0SXin Li         {
114*3e777be0SXin Li             return Fail("%s: Operation has invalid inputs %d", __func__, outputType);
115*3e777be0SXin Li         }
116*3e777be0SXin Li     }
117*3e777be0SXin Li     else if (outputType == HalOperandType::TENSOR_INT32)
118*3e777be0SXin Li     {
119*3e777be0SXin Li         int32_t value;
120*3e777be0SXin Li 
121*3e777be0SXin Li         if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::INT32, value, model, data))
122*3e777be0SXin Li         {
123*3e777be0SXin Li             return Fail("%s: Operation has invalid inputs %d", __func__, outputType);
124*3e777be0SXin Li         }
125*3e777be0SXin Li 
126*3e777be0SXin Li         descriptor.m_Value = static_cast<float>(value);
127*3e777be0SXin Li     }
128*3e777be0SXin Li     else
129*3e777be0SXin Li     {
130*3e777be0SXin Li         return Fail("%s: Unsupported input tensor type: %d", __func__, outputType);
131*3e777be0SXin Li     }
132*3e777be0SXin Li 
133*3e777be0SXin Li     bool isSupported = false;
134*3e777be0SXin Li     armnn::BackendId setBackend;
135*3e777be0SXin Li     FORWARD_LAYER_SUPPORT_FUNC(__func__,
136*3e777be0SXin Li                                IsFillSupported,
137*3e777be0SXin Li                                data.m_Backends,
138*3e777be0SXin Li                                isSupported,
139*3e777be0SXin Li                                setBackend,
140*3e777be0SXin Li                                inputInfo,
141*3e777be0SXin Li                                outputInfo,
142*3e777be0SXin Li                                descriptor);
143*3e777be0SXin Li     if (!isSupported)
144*3e777be0SXin Li     {
145*3e777be0SXin Li         return false;
146*3e777be0SXin Li     }
147*3e777be0SXin Li 
148*3e777be0SXin Li     IConnectableLayer* const layer = data.m_Network->AddFillLayer(descriptor);
149*3e777be0SXin Li     layer->SetBackendId(setBackend);
150*3e777be0SXin Li     if (!layer)
151*3e777be0SXin Li     {
152*3e777be0SXin Li         return Fail("%s: Could not add the FillLayer", __func__);
153*3e777be0SXin Li     }
154*3e777be0SXin Li     input.Connect(layer->GetInputSlot(0));
155*3e777be0SXin Li 
156*3e777be0SXin Li     return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, model, data);
157*3e777be0SXin Li }
158*3e777be0SXin Li 
159*3e777be0SXin Li template<typename HalPolicy,
160*3e777be0SXin Li          typename HalOperation = typename HalPolicy::Operation,
161*3e777be0SXin Li          typename HalModel     = typename HalPolicy::Model>
ConvertLogicalBinary(const HalOperation & operation,const HalModel & model,ConversionData & data,LogicalBinaryOperation logicalOperation)162*3e777be0SXin Li bool ConvertLogicalBinary(const HalOperation& operation,
163*3e777be0SXin Li                           const HalModel& model,
164*3e777be0SXin Li                           ConversionData& data,
165*3e777be0SXin Li                           LogicalBinaryOperation logicalOperation)
166*3e777be0SXin Li {
167*3e777be0SXin Li     using HalOperand = typename HalPolicy::Operand;
168*3e777be0SXin Li 
169*3e777be0SXin Li     ALOGV("HalPolicy::ConvertLogicalBinary()");
170*3e777be0SXin Li     ALOGV("logicalOperation = %s", GetLogicalBinaryOperationAsCString(logicalOperation));
171*3e777be0SXin Li 
172*3e777be0SXin Li     LayerInputHandle input0 = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);
173*3e777be0SXin Li     LayerInputHandle input1 = ConvertToLayerInputHandle<HalPolicy>(operation, 1, model, data);
174*3e777be0SXin Li 
175*3e777be0SXin Li     if (!(input0.IsValid() && input1.IsValid()))
176*3e777be0SXin Li     {
177*3e777be0SXin Li         return Fail("%s: Operation has invalid inputs", __func__);
178*3e777be0SXin Li     }
179*3e777be0SXin Li 
180*3e777be0SXin Li     const HalOperand* output = GetOutputOperand<HalPolicy>(operation, 0, model);
181*3e777be0SXin Li     if (!output)
182*3e777be0SXin Li     {
183*3e777be0SXin Li         return Fail("%s: Could not read output 0", __func__);
184*3e777be0SXin Li     }
185*3e777be0SXin Li 
186*3e777be0SXin Li     const TensorInfo& inputInfo0 = input0.GetTensorInfo();
187*3e777be0SXin Li     const TensorInfo& inputInfo1 = input1.GetTensorInfo();
188*3e777be0SXin Li     const TensorInfo& outputInfo = GetTensorInfoForOperand(*output);
189*3e777be0SXin Li 
190*3e777be0SXin Li     LogicalBinaryDescriptor descriptor(logicalOperation);
191*3e777be0SXin Li 
192*3e777be0SXin Li     bool isSupported = false;
193*3e777be0SXin Li     armnn::BackendId setBackend;
194*3e777be0SXin Li     auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported)
195*3e777be0SXin Li     {
196*3e777be0SXin Li         FORWARD_LAYER_SUPPORT_FUNC(__func__,
197*3e777be0SXin Li                                    IsLogicalBinarySupported,
198*3e777be0SXin Li                                    data.m_Backends,
199*3e777be0SXin Li                                    isSupported,
200*3e777be0SXin Li                                    setBackend,
201*3e777be0SXin Li                                    inputInfo0,
202*3e777be0SXin Li                                    inputInfo1,
203*3e777be0SXin Li                                    outputInfo,
204*3e777be0SXin Li                                    descriptor);
205*3e777be0SXin Li     };
206*3e777be0SXin Li 
207*3e777be0SXin Li     if(!IsDynamicTensor(outputInfo))
208*3e777be0SXin Li     {
209*3e777be0SXin Li         validateFunc(outputInfo, isSupported);
210*3e777be0SXin Li     }
211*3e777be0SXin Li     else
212*3e777be0SXin Li     {
213*3e777be0SXin Li         isSupported = AreDynamicTensorsSupported();
214*3e777be0SXin Li     }
215*3e777be0SXin Li 
216*3e777be0SXin Li     if (!isSupported)
217*3e777be0SXin Li     {
218*3e777be0SXin Li         return false;
219*3e777be0SXin Li     }
220*3e777be0SXin Li 
221*3e777be0SXin Li     IConnectableLayer* layer = data.m_Network->AddLogicalBinaryLayer(descriptor);
222*3e777be0SXin Li     layer->SetBackendId(setBackend);
223*3e777be0SXin Li     if (!layer)
224*3e777be0SXin Li     {
225*3e777be0SXin Li         return Fail("%s: Could not add the LogicalBinaryLayer", __func__);
226*3e777be0SXin Li     }
227*3e777be0SXin Li 
228*3e777be0SXin Li     bool isReshapeSupported = BroadcastTensor(input0, input1, layer, data);
229*3e777be0SXin Li     if (!isReshapeSupported)
230*3e777be0SXin Li     {
231*3e777be0SXin Li         return false;
232*3e777be0SXin Li     }
233*3e777be0SXin Li 
234*3e777be0SXin Li     return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, model, data, nullptr, validateFunc);
235*3e777be0SXin Li }
236*3e777be0SXin Li 
237*3e777be0SXin Li template<typename HalPolicy,
238*3e777be0SXin Li          typename HalOperation = typename HalPolicy::Operation,
239*3e777be0SXin Li          typename HalModel     = typename HalPolicy::Model>
ConvertQuantizedLstm(const HalOperation & operation,const HalModel & model,ConversionData & data)240*3e777be0SXin Li bool ConvertQuantizedLstm(const HalOperation& operation, const HalModel& model, ConversionData& data)
241*3e777be0SXin Li {
242*3e777be0SXin Li     using HalOperand     = typename HalPolicy::Operand;
243*3e777be0SXin Li     using HalOperandType = typename HalPolicy::OperandType;
244*3e777be0SXin Li 
245*3e777be0SXin Li     ALOGV("HalPolicy::ConvertQuantizedLstm()");
246*3e777be0SXin Li 
247*3e777be0SXin Li     //Inputs:
248*3e777be0SXin Li     // 0: The input: A 2-D tensor of type ANEURALNETWORKS_TENSOR_QUANT8_ASYMM and shape [numBatches, inputSize]
249*3e777be0SXin Li     //    specifying the input to the LSTM cell. Tensor is quantized with a fixed quantization range of -1, 127/128.
250*3e777be0SXin Li     LayerInputHandle input = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);
251*3e777be0SXin Li     if (!input.IsValid())
252*3e777be0SXin Li     {
253*3e777be0SXin Li         return Fail("%s: Could not read input 0: input", __func__);
254*3e777be0SXin Li     }
255*3e777be0SXin Li 
256*3e777be0SXin Li     // 18: The output state: A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_ASYMM, of shape [batch_size, output_size].
257*3e777be0SXin Li     LayerInputHandle outputStatePrevTimeStep = ConvertToLayerInputHandle<HalPolicy>(operation, 18, model, data);
258*3e777be0SXin Li     if (!outputStatePrevTimeStep.IsValid())
259*3e777be0SXin Li     {
260*3e777be0SXin Li         return Fail("%s: Could not read input 18: outputStatePrevTimeStep", __func__);
261*3e777be0SXin Li     }
262*3e777be0SXin Li 
263*3e777be0SXin Li     // 19: The cell state: A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT16_SYMM, of shape [batch_size, num_units].
264*3e777be0SXin Li     LayerInputHandle cellStatePrevTimeStep = ConvertToLayerInputHandle<HalPolicy>(operation, 19, model, data);
265*3e777be0SXin Li     if (!cellStatePrevTimeStep.IsValid())
266*3e777be0SXin Li     {
267*3e777be0SXin Li         return Fail("%s: Could not read input 19: cellStatePrevTimeStep", __func__);
268*3e777be0SXin Li     }
269*3e777be0SXin Li 
270*3e777be0SXin Li     // Get the mandatory input tensors:
271*3e777be0SXin Li 
272*3e777be0SXin Li     // 02: The input-to-forget weights: A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_SYMM, of shape
273*3e777be0SXin Li     //     [num_units, input_size].
274*3e777be0SXin Li     const ConstTensorPin inputToForgetWeightsPin =
275*3e777be0SXin Li         ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 2, model, data);
276*3e777be0SXin Li 
277*3e777be0SXin Li     // 03: The input-to-cell weights: A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_SYMM, of shape
278*3e777be0SXin Li     // [num_units, input_size].
279*3e777be0SXin Li     const ConstTensorPin inputToCellWeightsPin =
280*3e777be0SXin Li         ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 3, model, data);
281*3e777be0SXin Li 
282*3e777be0SXin Li     // 04: The input-to-output weights: A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_SYMM, of shape
283*3e777be0SXin Li     //     [num_units, input_size].
284*3e777be0SXin Li     const ConstTensorPin inputToOutputWeightsPin =
285*3e777be0SXin Li         ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 4, model, data);
286*3e777be0SXin Li 
287*3e777be0SXin Li     // 06: The recurrent-to-forget weights: A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_SYMM, of shape
288*3e777be0SXin Li     //     [num_units, output_size].
289*3e777be0SXin Li     const ConstTensorPin recurrentToForgetWeightsPin =
290*3e777be0SXin Li         ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 6, model, data);
291*3e777be0SXin Li 
292*3e777be0SXin Li     // 07: The recurrent-to-cell weights: A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_SYMM, of shape
293*3e777be0SXin Li     //     [num_units, output_size].
294*3e777be0SXin Li     const ConstTensorPin recurrentToCellWeightsPin =
295*3e777be0SXin Li         ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 7, model, data);
296*3e777be0SXin Li 
297*3e777be0SXin Li     // 08: The recurrent-to-output weights: A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_SYMM, of shape
298*3e777be0SXin Li     //     [num_units, output_size].
299*3e777be0SXin Li     const ConstTensorPin recurrentToOutputWeightsPin =
300*3e777be0SXin Li         ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 8, model, data);
301*3e777be0SXin Li 
302*3e777be0SXin Li     // 13: The forget gate bias: A 1-D tensor of ANEURALNETWORKS_TENSOR_INT32, of shape [num_units].
303*3e777be0SXin Li     const ConstTensorPin forgetGateBiasPin =
304*3e777be0SXin Li         ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 13, model, data);
305*3e777be0SXin Li 
306*3e777be0SXin Li     // 14: The cell bias: A 1-D tensor of ANEURALNETWORKS_TENSOR_INT32, of shape [num_units].
307*3e777be0SXin Li     const ConstTensorPin cellBiasPin =
308*3e777be0SXin Li         ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 14, model, data);
309*3e777be0SXin Li 
310*3e777be0SXin Li     // 15: The output gate bias: A 1-D tensor of ANEURALNETWORKS_TENSOR_INT32, of shape [num_units].
311*3e777be0SXin Li     const ConstTensorPin outputGateBiasPin =
312*3e777be0SXin Li         ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 15, model, data);
313*3e777be0SXin Li 
314*3e777be0SXin Li     if (!inputToForgetWeightsPin.IsValid() ||
315*3e777be0SXin Li         !inputToCellWeightsPin.IsValid() ||
316*3e777be0SXin Li         !inputToOutputWeightsPin.IsValid() ||
317*3e777be0SXin Li         !recurrentToForgetWeightsPin.IsValid() ||
318*3e777be0SXin Li         !recurrentToCellWeightsPin.IsValid() ||
319*3e777be0SXin Li         !recurrentToOutputWeightsPin.IsValid() ||
320*3e777be0SXin Li         !forgetGateBiasPin.IsValid() ||
321*3e777be0SXin Li         !cellBiasPin.IsValid() ||
322*3e777be0SXin Li         !outputGateBiasPin.IsValid())
323*3e777be0SXin Li     {
324*3e777be0SXin Li         return Fail("%s: Operation has invalid tensor inputs", __func__);
325*3e777be0SXin Li     }
326*3e777be0SXin Li 
327*3e777be0SXin Li     // Get the optional input tensors:
328*3e777be0SXin Li 
329*3e777be0SXin Li     // 01: The input-to-input weights: Optional. A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_SYMM, of shape
330*3e777be0SXin Li     //     [num_units, input_size], where “num_units” corresponds to the number of cell units.
331*3e777be0SXin Li     const ConstTensorPin inputToInputWeightsPin =
332*3e777be0SXin Li         ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
333*3e777be0SXin Li                                                          1,
334*3e777be0SXin Li                                                          model,
335*3e777be0SXin Li                                                          data,
336*3e777be0SXin Li                                                          g_DontPermute,
337*3e777be0SXin Li                                                          nullptr,
338*3e777be0SXin Li                                                          true);
339*3e777be0SXin Li 
340*3e777be0SXin Li     // 05: The recurrent-to-input weights: Optional. A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_SYMM, of shape
341*3e777be0SXin Li     //     [num_units, output_size], where “output_size” corresponds to either the number of cell units (i.e.,
342*3e777be0SXin Li     //     “num_units”), or the second dimension of the “projection_weights”, if defined.
343*3e777be0SXin Li     const ConstTensorPin recurrentToInputWeightsPin =
344*3e777be0SXin Li         ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
345*3e777be0SXin Li                                                          5,
346*3e777be0SXin Li                                                          model,
347*3e777be0SXin Li                                                          data,
348*3e777be0SXin Li                                                          g_DontPermute,
349*3e777be0SXin Li                                                          nullptr,
350*3e777be0SXin Li                                                          true);
351*3e777be0SXin Li 
352*3e777be0SXin Li     // 09: The cell-to-input weights: Optional. A 1-D tensor of ANEURALNETWORKS_TENSOR_QUANT16_SYMM, of shape
353*3e777be0SXin Li     // [num_units].
354*3e777be0SXin Li     const ConstTensorPin cellToInputWeightsPin =
355*3e777be0SXin Li         ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
356*3e777be0SXin Li                                                          9,
357*3e777be0SXin Li                                                          model,
358*3e777be0SXin Li                                                          data,
359*3e777be0SXin Li                                                          g_DontPermute,
360*3e777be0SXin Li                                                          nullptr,
361*3e777be0SXin Li                                                          true);
362*3e777be0SXin Li 
363*3e777be0SXin Li     // 10: The cell-to-forget weights: Optional. A 1-D tensor of ANEURALNETWORKS_TENSOR_QUANT16_SYMM, of shape
364*3e777be0SXin Li     // [num_units].
365*3e777be0SXin Li     const ConstTensorPin cellToForgetWeightsPin =
366*3e777be0SXin Li         ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
367*3e777be0SXin Li                                                          10,
368*3e777be0SXin Li                                                          model,
369*3e777be0SXin Li                                                          data,
370*3e777be0SXin Li                                                          g_DontPermute,
371*3e777be0SXin Li                                                          nullptr,
372*3e777be0SXin Li                                                          true);
373*3e777be0SXin Li 
374*3e777be0SXin Li     // 11: The cell-to-output weights: Optional. A 1-D tensor of ANEURALNETWORKS_TENSOR_QUANT16_SYMM, of shape
375*3e777be0SXin Li     // [num_units].
376*3e777be0SXin Li     const ConstTensorPin cellToOutputWeightsPin =
377*3e777be0SXin Li         ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
378*3e777be0SXin Li                                                          11,
379*3e777be0SXin Li                                                          model,
380*3e777be0SXin Li                                                          data,
381*3e777be0SXin Li                                                          g_DontPermute,
382*3e777be0SXin Li                                                          nullptr,
383*3e777be0SXin Li                                                          true);
384*3e777be0SXin Li 
385*3e777be0SXin Li     // 12: The input gate bias: Optional. A 1-D tensor of ANEURALNETWORKS_TENSOR_INT32, of shape [num_units].
386*3e777be0SXin Li     const ConstTensorPin inputGateBiasPin =
387*3e777be0SXin Li         ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
388*3e777be0SXin Li                                                          12,
389*3e777be0SXin Li                                                          model,
390*3e777be0SXin Li                                                          data,
391*3e777be0SXin Li                                                          g_DontPermute,
392*3e777be0SXin Li                                                          nullptr,
393*3e777be0SXin Li                                                          true);
394*3e777be0SXin Li 
395*3e777be0SXin Li     // 16: The projection weights: Optional. A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_SYMM, of shape
396*3e777be0SXin Li     //     [output_size, num_units].
397*3e777be0SXin Li     const ConstTensorPin projectionWeightsPin =
398*3e777be0SXin Li         ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
399*3e777be0SXin Li                                                          16,
400*3e777be0SXin Li                                                          model,
401*3e777be0SXin Li                                                          data,
402*3e777be0SXin Li                                                          g_DontPermute,
403*3e777be0SXin Li                                                          nullptr,
404*3e777be0SXin Li                                                          true);
405*3e777be0SXin Li 
406*3e777be0SXin Li     // 17: The projection bias: Optional. A 1-D tensor of ANEURALNETWORKS_TENSOR_INT32, of shape [output_size].
407*3e777be0SXin Li     const ConstTensorPin projectionBiasPin =
408*3e777be0SXin Li         ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
409*3e777be0SXin Li                                                          17,
410*3e777be0SXin Li                                                          model,
411*3e777be0SXin Li                                                          data,
412*3e777be0SXin Li                                                          g_DontPermute,
413*3e777be0SXin Li                                                          nullptr,
414*3e777be0SXin Li                                                          true);
415*3e777be0SXin Li 
416*3e777be0SXin Li     if ((!inputToInputWeightsPin.IsValid() && !inputToInputWeightsPin.IsOptional())
417*3e777be0SXin Li         || (!recurrentToInputWeightsPin.IsValid() && !recurrentToInputWeightsPin.IsOptional())
418*3e777be0SXin Li         || (!cellToInputWeightsPin.IsValid() && !cellToInputWeightsPin.IsOptional())
419*3e777be0SXin Li         || (!cellToForgetWeightsPin.IsValid() && !cellToForgetWeightsPin.IsOptional())
420*3e777be0SXin Li         || (!cellToOutputWeightsPin.IsValid() && !cellToOutputWeightsPin.IsOptional())
421*3e777be0SXin Li         || (!inputGateBiasPin.IsValid() && !inputGateBiasPin.IsOptional())
422*3e777be0SXin Li         || (!projectionWeightsPin.IsValid() && !projectionWeightsPin.IsOptional())
423*3e777be0SXin Li         || (!projectionBiasPin.IsValid() && !projectionBiasPin.IsOptional()))
424*3e777be0SXin Li     {
425*3e777be0SXin Li         return Fail("%s: Operation has invalid tensor inputs", __func__);
426*3e777be0SXin Li     }
427*3e777be0SXin Li 
428*3e777be0SXin Li 
429*3e777be0SXin Li     // Get the optional normalization tensors
430*3e777be0SXin Li 
431*3e777be0SXin Li     // 20: The input layer normalization weights. A 1-D tensor of shape [num_units] ANEURALNETWORKS_TENSOR_QUANT16_SYMM.
432*3e777be0SXin Li     //     Used to rescale normalized inputs to activation at input gate.
433*3e777be0SXin Li     const ConstTensorPin inputLayerNormWeightsPin =
434*3e777be0SXin Li         ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
435*3e777be0SXin Li                                                          20,
436*3e777be0SXin Li                                                          model,
437*3e777be0SXin Li                                                          data,
438*3e777be0SXin Li                                                          g_DontPermute,
439*3e777be0SXin Li                                                          nullptr,
440*3e777be0SXin Li                                                          true);
441*3e777be0SXin Li 
442*3e777be0SXin Li     // 21: The forget layer normalization weights. A 1-D tensor of shape [num_units] ANEURALNETWORKS_TENSOR_QUANT16_SYMM
443*3e777be0SXin Li     //     Used to rescale normalized inputs to activation at forget gate.
444*3e777be0SXin Li     const ConstTensorPin forgetLayerNormWeightsPin =
445*3e777be0SXin Li         ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
446*3e777be0SXin Li                                                          21,
447*3e777be0SXin Li                                                          model,
448*3e777be0SXin Li                                                          data,
449*3e777be0SXin Li                                                          g_DontPermute,
450*3e777be0SXin Li                                                          nullptr,
451*3e777be0SXin Li                                                          true);
452*3e777be0SXin Li 
453*3e777be0SXin Li     // 22: The cell layer normalization weights. A 1-D tensor of shape [num_units] ANEURALNETWORKS_TENSOR_QUANT16_SYMM.
454*3e777be0SXin Li     //     Used to rescale normalized inputs to activation at cell gate.
455*3e777be0SXin Li     const ConstTensorPin cellLayerNormWeightsPin =
456*3e777be0SXin Li         ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
457*3e777be0SXin Li                                                          22,
458*3e777be0SXin Li                                                          model,
459*3e777be0SXin Li                                                          data,
460*3e777be0SXin Li                                                          g_DontPermute,
461*3e777be0SXin Li                                                          nullptr,
462*3e777be0SXin Li                                                          true);
463*3e777be0SXin Li 
464*3e777be0SXin Li     // 23: The output layer normalization weights. A 1-D tensor of shape [num_units].
465*3e777be0SXin Li     //     Used to rescale normalized inputs to activation at output gate.
466*3e777be0SXin Li     const ConstTensorPin outputLayerNormWeightsPin =
467*3e777be0SXin Li         ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
468*3e777be0SXin Li                                                          23,
469*3e777be0SXin Li                                                          model,
470*3e777be0SXin Li                                                          data,
471*3e777be0SXin Li                                                          g_DontPermute,
472*3e777be0SXin Li                                                          nullptr,
473*3e777be0SXin Li                                                          true);
474*3e777be0SXin Li 
475*3e777be0SXin Li     if ((!inputLayerNormWeightsPin.IsValid() && !inputLayerNormWeightsPin.IsOptional())
476*3e777be0SXin Li         || (!forgetLayerNormWeightsPin.IsValid() && !forgetLayerNormWeightsPin.IsOptional())
477*3e777be0SXin Li         || (!cellLayerNormWeightsPin.IsValid() && !cellLayerNormWeightsPin.IsOptional())
478*3e777be0SXin Li         || (!outputLayerNormWeightsPin.IsValid() && !outputLayerNormWeightsPin.IsOptional()))
479*3e777be0SXin Li     {
480*3e777be0SXin Li         return Fail("%s: Operation has invalid tensor inputs", __func__);
481*3e777be0SXin Li     }
482*3e777be0SXin Li 
483*3e777be0SXin Li     // Get the optional input scalars:
484*3e777be0SXin Li     // 24: The cell clip:  If provided the cell state is clipped by this value prior to the cell output activation.
485*3e777be0SXin Li     // 25: The projection clip: If provided and projection is enabled, this is used for clipping the projected values.
486*3e777be0SXin Li 
487*3e777be0SXin Li     // Get the mandatory input scalars:
488*3e777be0SXin Li     // 26: The scale of the intermediate result of matmul, i.e. input to layer normalization, at input gate.
489*3e777be0SXin Li     // 27: The scale of the intermediate result of matmul, i.e. input to layer normalization, at forget gate.
490*3e777be0SXin Li     // 28: The scale of the intermediate result of matmul, i.e. input to layer normalization, at cell gate.
491*3e777be0SXin Li     // 29: The scale of the intermediate result of matmul, i.e. input to layer normalization, at output gate.
492*3e777be0SXin Li     // 30: The zero point of the hidden state, i.e. input to projection.
493*3e777be0SXin Li     // 31: The scale of the hidden state, i.e. input to projection.
494*3e777be0SXin Li     float cellClip, projClip, matMulInputGate, matMulForgetGate, matMulCellGate, matMulOutputGate, projInputScale;
495*3e777be0SXin Li     int projInputZeroPoint;
496*3e777be0SXin Li 
497*3e777be0SXin Li     if (!GetInputScalar<HalPolicy>(operation, 24, HalOperandType::FLOAT32, cellClip, model, data, true) ||
498*3e777be0SXin Li         !GetInputScalar<HalPolicy>(operation, 25, HalOperandType::FLOAT32, projClip, model, data, true) ||
499*3e777be0SXin Li         !GetInputScalar<HalPolicy>(operation, 26, HalOperandType::FLOAT32, matMulInputGate, model, data) ||
500*3e777be0SXin Li         !GetInputScalar<HalPolicy>(operation, 27, HalOperandType::FLOAT32, matMulForgetGate, model, data) ||
501*3e777be0SXin Li         !GetInputScalar<HalPolicy>(operation, 28, HalOperandType::FLOAT32, matMulCellGate, model, data) ||
502*3e777be0SXin Li         !GetInputScalar<HalPolicy>(operation, 29, HalOperandType::FLOAT32, matMulOutputGate, model, data) ||
503*3e777be0SXin Li         !GetInputScalar<HalPolicy>(operation, 30, HalOperandType::INT32, projInputZeroPoint, model, data) ||
504*3e777be0SXin Li         !GetInputScalar<HalPolicy>(operation, 31, HalOperandType::FLOAT32, projInputScale, model, data))
505*3e777be0SXin Li     {
506*3e777be0SXin Li         return Fail("%s: Operation has invalid scalar inputs", __func__);
507*3e777be0SXin Li     }
508*3e777be0SXin Li 
509*3e777be0SXin Li     // Outputs:
510*3e777be0SXin Li     // 0: The output state (out): A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_ASYMM_SIGNED, of shape [batch_size,
511*3e777be0SXin Li     // output_size].
512*3e777be0SXin Li     const HalOperand* outputStateOut = GetOutputOperand<HalPolicy>(operation, 0, model);
513*3e777be0SXin Li     if (!outputStateOut)
514*3e777be0SXin Li     {
515*3e777be0SXin Li         return Fail("%s: Could not read output 0: outputStateOut", __func__);
516*3e777be0SXin Li     }
517*3e777be0SXin Li 
518*3e777be0SXin Li     // 1: The cell state (out): A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT16_SYMM, of shape [batch_size, num_units].
519*3e777be0SXin Li     const HalOperand* cellStateOut = GetOutputOperand<HalPolicy>(operation, 1, model);
520*3e777be0SXin Li     if (!cellStateOut)
521*3e777be0SXin Li     {
522*3e777be0SXin Li         return Fail("%s: Could not read output 1: cellStateOut", __func__);
523*3e777be0SXin Li     }
524*3e777be0SXin Li 
525*3e777be0SXin Li     // 2: The output: A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_ASYMM_SIGNED, of shape [batch_size, output_size].
526*3e777be0SXin Li     // This is effectively the same as the current “output state (out)” value.
527*3e777be0SXin Li     const HalOperand* output = GetOutputOperand<HalPolicy>(operation, 2, model);
528*3e777be0SXin Li     if (!output)
529*3e777be0SXin Li     {
530*3e777be0SXin Li         return Fail("%s: Could not read output 2: output", __func__);
531*3e777be0SXin Li     }
532*3e777be0SXin Li 
533*3e777be0SXin Li     // set the params structure for the AddLstmLayer call
534*3e777be0SXin Li     LstmInputParams params;
535*3e777be0SXin Li     params.m_InputToInputWeights = inputToInputWeightsPin.GetConstTensorPtr();
536*3e777be0SXin Li     params.m_InputToForgetWeights = inputToForgetWeightsPin.GetConstTensorPtr();
537*3e777be0SXin Li     params.m_InputToCellWeights = inputToCellWeightsPin.GetConstTensorPtr();
538*3e777be0SXin Li     params.m_InputToOutputWeights = inputToOutputWeightsPin.GetConstTensorPtr();
539*3e777be0SXin Li     params.m_RecurrentToInputWeights = recurrentToInputWeightsPin.GetConstTensorPtr();
540*3e777be0SXin Li     params.m_RecurrentToForgetWeights = recurrentToForgetWeightsPin.GetConstTensorPtr();
541*3e777be0SXin Li     params.m_RecurrentToCellWeights = recurrentToCellWeightsPin.GetConstTensorPtr();
542*3e777be0SXin Li     params.m_RecurrentToOutputWeights = recurrentToOutputWeightsPin.GetConstTensorPtr();
543*3e777be0SXin Li     params.m_CellToInputWeights = cellToInputWeightsPin.GetConstTensorPtr();
544*3e777be0SXin Li     params.m_CellToForgetWeights = cellToForgetWeightsPin.GetConstTensorPtr();
545*3e777be0SXin Li     params.m_CellToOutputWeights = cellToOutputWeightsPin.GetConstTensorPtr();
546*3e777be0SXin Li     params.m_InputGateBias = inputGateBiasPin.GetConstTensorPtr();
547*3e777be0SXin Li     params.m_ForgetGateBias = forgetGateBiasPin.GetConstTensorPtr();
548*3e777be0SXin Li     params.m_CellBias = cellBiasPin.GetConstTensorPtr();
549*3e777be0SXin Li     params.m_OutputGateBias = outputGateBiasPin.GetConstTensorPtr();
550*3e777be0SXin Li     params.m_ProjectionWeights = projectionWeightsPin.GetConstTensorPtr();
551*3e777be0SXin Li     params.m_ProjectionBias = projectionBiasPin.GetConstTensorPtr();
552*3e777be0SXin Li     params.m_InputLayerNormWeights = inputLayerNormWeightsPin.GetConstTensorPtr();
553*3e777be0SXin Li     params.m_ForgetLayerNormWeights = forgetLayerNormWeightsPin.GetConstTensorPtr();
554*3e777be0SXin Li     params.m_CellLayerNormWeights = cellLayerNormWeightsPin.GetConstTensorPtr();
555*3e777be0SXin Li     params.m_OutputLayerNormWeights = outputLayerNormWeightsPin.GetConstTensorPtr();
556*3e777be0SXin Li 
557*3e777be0SXin Li     // set the layer descriptor
558*3e777be0SXin Li     QLstmDescriptor desc;
559*3e777be0SXin Li     desc.m_CellClip = cellClip;
560*3e777be0SXin Li     desc.m_ProjectionClip = projClip;
561*3e777be0SXin Li     desc.m_CifgEnabled = (params.m_InputToInputWeights == nullptr ||
562*3e777be0SXin Li                           params.m_RecurrentToInputWeights == nullptr ||
563*3e777be0SXin Li                           params.m_InputGateBias == nullptr);
564*3e777be0SXin Li     desc.m_PeepholeEnabled = (params.m_CellToForgetWeights != nullptr ||
565*3e777be0SXin Li                               params.m_CellToOutputWeights != nullptr);
566*3e777be0SXin Li     desc.m_ProjectionEnabled = (params.m_ProjectionWeights != nullptr);
567*3e777be0SXin Li     desc.m_LayerNormEnabled = (params.m_InputLayerNormWeights != nullptr ||
568*3e777be0SXin Li                                params.m_ForgetLayerNormWeights != nullptr ||
569*3e777be0SXin Li                                params.m_CellLayerNormWeights != nullptr ||
570*3e777be0SXin Li                                params.m_OutputLayerNormWeights != nullptr);
571*3e777be0SXin Li     desc.m_InputIntermediateScale = matMulInputGate;
572*3e777be0SXin Li     desc.m_ForgetIntermediateScale = matMulForgetGate;
573*3e777be0SXin Li     desc.m_CellIntermediateScale = matMulCellGate;
574*3e777be0SXin Li     desc.m_OutputIntermediateScale = matMulOutputGate;
575*3e777be0SXin Li     desc.m_HiddenStateScale = projInputScale;
576*3e777be0SXin Li     desc.m_HiddenStateZeroPoint = projInputZeroPoint;
577*3e777be0SXin Li 
578*3e777be0SXin Li     // validate the optional input groups
579*3e777be0SXin Li     if (desc.m_CifgEnabled &&
580*3e777be0SXin Li         (params.m_InputToInputWeights != nullptr ||
581*3e777be0SXin Li          params.m_RecurrentToInputWeights != nullptr ||
582*3e777be0SXin Li          params.m_InputGateBias != nullptr))
583*3e777be0SXin Li     {
584*3e777be0SXin Li         return Fail("%s: All, or none, of input-to-input weights, recurrent-to-input weights,"
585*3e777be0SXin Li                     " and input gate bias must be provided", __func__);
586*3e777be0SXin Li     }
587*3e777be0SXin Li 
588*3e777be0SXin Li     if (!desc.m_ProjectionEnabled && params.m_ProjectionBias != nullptr)
589*3e777be0SXin Li     {
590*3e777be0SXin Li         return Fail("%s: projection bias should not be provided without projection weights", __func__);
591*3e777be0SXin Li     }
592*3e777be0SXin Li 
593*3e777be0SXin Li     if (desc.m_PeepholeEnabled &&
594*3e777be0SXin Li         (params.m_CellToForgetWeights == nullptr ||
595*3e777be0SXin Li          params.m_CellToOutputWeights == nullptr ||
596*3e777be0SXin Li          (!desc.m_CifgEnabled && params.m_CellToInputWeights == nullptr)))
597*3e777be0SXin Li     {
598*3e777be0SXin Li         return Fail("%s: All, or none, of cell-to-forget weights and cell-to-output weights must be provided"
599*3e777be0SXin Li                     " and, if CIFG is not enabled, cell-to-input weights must also be provided", __func__);
600*3e777be0SXin Li     }
601*3e777be0SXin Li 
602*3e777be0SXin Li     if (desc.m_LayerNormEnabled &&
603*3e777be0SXin Li         (params.m_ForgetLayerNormWeights == nullptr ||
604*3e777be0SXin Li          params.m_CellLayerNormWeights == nullptr ||
605*3e777be0SXin Li          params.m_OutputLayerNormWeights == nullptr ||
606*3e777be0SXin Li          (!desc.m_CifgEnabled && params.m_InputLayerNormWeights == nullptr)))
607*3e777be0SXin Li     {
608*3e777be0SXin Li         return Fail("%s: All, or none, of forget-norm weights, cell-norm weights and output-norm weights must be"
609*3e777be0SXin Li                     " provided and, if CIFG is not enabled, input-norm weights must also be provided", __func__);
610*3e777be0SXin Li     }
611*3e777be0SXin Li 
612*3e777be0SXin Li 
613*3e777be0SXin Li     // Basic parameters
614*3e777be0SXin Li     LstmInputParamsInfo paramsInfo;
615*3e777be0SXin Li     paramsInfo.m_InputToForgetWeights     = &(params.m_InputToForgetWeights->GetInfo());
616*3e777be0SXin Li     paramsInfo.m_InputToCellWeights       = &(params.m_InputToCellWeights->GetInfo());
617*3e777be0SXin Li     paramsInfo.m_InputToOutputWeights     = &(params.m_InputToOutputWeights->GetInfo());
618*3e777be0SXin Li     paramsInfo.m_RecurrentToForgetWeights = &(params.m_RecurrentToForgetWeights->GetInfo());
619*3e777be0SXin Li     paramsInfo.m_RecurrentToCellWeights   = &(params.m_RecurrentToCellWeights->GetInfo());
620*3e777be0SXin Li     paramsInfo.m_RecurrentToOutputWeights = &(params.m_RecurrentToOutputWeights->GetInfo());
621*3e777be0SXin Li     paramsInfo.m_ForgetGateBias           = &(params.m_ForgetGateBias->GetInfo());
622*3e777be0SXin Li     paramsInfo.m_CellBias                 = &(params.m_CellBias->GetInfo());
623*3e777be0SXin Li     paramsInfo.m_OutputGateBias           = &(params.m_OutputGateBias->GetInfo());
624*3e777be0SXin Li 
625*3e777be0SXin Li     // Inputs
626*3e777be0SXin Li     const TensorInfo& inputInfo = input.GetTensorInfo();
627*3e777be0SXin Li     const TensorInfo& outputStatePrevTimeStepInfo = outputStatePrevTimeStep.GetTensorInfo();
628*3e777be0SXin Li     const TensorInfo& cellStatePrevTimeStepInfo = cellStatePrevTimeStep.GetTensorInfo();
629*3e777be0SXin Li 
630*3e777be0SXin Li     // Outputs
631*3e777be0SXin Li     TensorInfo outputStateOutInfo = GetTensorInfoForOperand(*outputStateOut);
632*3e777be0SXin Li     TensorInfo outputInfo = GetTensorInfoForOperand(*output);
633*3e777be0SXin Li     const TensorInfo& cellStateOutInfo = GetTensorInfoForOperand(*cellStateOut);
634*3e777be0SXin Li 
635*3e777be0SXin Li     // Optional parameters
636*3e777be0SXin Li     if (!desc.m_CifgEnabled)
637*3e777be0SXin Li     {
638*3e777be0SXin Li         paramsInfo.m_InputToInputWeights = &(params.m_InputToInputWeights->GetInfo());
639*3e777be0SXin Li         paramsInfo.m_RecurrentToInputWeights = &(params.m_RecurrentToInputWeights->GetInfo());
640*3e777be0SXin Li         if (desc.m_PeepholeEnabled)
641*3e777be0SXin Li         {
642*3e777be0SXin Li             paramsInfo.m_CellToInputWeights = &(params.m_CellToInputWeights->GetInfo());
643*3e777be0SXin Li         }
644*3e777be0SXin Li         paramsInfo.m_InputGateBias = &(params.m_InputGateBias->GetInfo());
645*3e777be0SXin Li     }
646*3e777be0SXin Li 
647*3e777be0SXin Li 
648*3e777be0SXin Li     if (desc.m_ProjectionEnabled)
649*3e777be0SXin Li     {
650*3e777be0SXin Li         paramsInfo.m_ProjectionWeights = &(params.m_ProjectionWeights->GetInfo());
651*3e777be0SXin Li         if (params.m_ProjectionBias != nullptr)
652*3e777be0SXin Li         {
653*3e777be0SXin Li             paramsInfo.m_ProjectionBias = &(params.m_ProjectionBias->GetInfo());
654*3e777be0SXin Li         }
655*3e777be0SXin Li     }
656*3e777be0SXin Li     else
657*3e777be0SXin Li     {
658*3e777be0SXin Li         // If Projection is disabled, override non-const outputs to change the quant info with hidden params, then
659*3e777be0SXin Li         // create a new const TensorInfo based on this
660*3e777be0SXin Li         outputStateOutInfo.SetQuantizationScale(projInputScale);
661*3e777be0SXin Li         outputStateOutInfo.SetQuantizationOffset(projInputZeroPoint);
662*3e777be0SXin Li         outputInfo.SetQuantizationScale(projInputScale);
663*3e777be0SXin Li         outputInfo.SetQuantizationOffset(projInputZeroPoint);
664*3e777be0SXin Li     }
665*3e777be0SXin Li 
666*3e777be0SXin Li     const TensorInfo constOutputStateOutInfo(outputStateOutInfo);
667*3e777be0SXin Li     const TensorInfo constOutputInfo(outputInfo);
668*3e777be0SXin Li 
669*3e777be0SXin Li     if (desc.m_PeepholeEnabled)
670*3e777be0SXin Li     {
671*3e777be0SXin Li         paramsInfo.m_CellToForgetWeights = &(params.m_CellToForgetWeights->GetInfo());
672*3e777be0SXin Li         paramsInfo.m_CellToOutputWeights = &(params.m_CellToOutputWeights->GetInfo());
673*3e777be0SXin Li     }
674*3e777be0SXin Li 
675*3e777be0SXin Li     if (desc.m_LayerNormEnabled)
676*3e777be0SXin Li     {
677*3e777be0SXin Li         if(!desc.m_CifgEnabled)
678*3e777be0SXin Li         {
679*3e777be0SXin Li             paramsInfo.m_InputLayerNormWeights = &(params.m_InputLayerNormWeights->GetInfo());
680*3e777be0SXin Li         }
681*3e777be0SXin Li         paramsInfo.m_ForgetLayerNormWeights = &(params.m_ForgetLayerNormWeights->GetInfo());
682*3e777be0SXin Li         paramsInfo.m_CellLayerNormWeights = &(params.m_CellLayerNormWeights->GetInfo());
683*3e777be0SXin Li         paramsInfo.m_OutputLayerNormWeights = &(params.m_OutputLayerNormWeights->GetInfo());
684*3e777be0SXin Li     }
685*3e777be0SXin Li 
686*3e777be0SXin Li     // Check if the layer is supported
687*3e777be0SXin Li     bool isSupported = false;
688*3e777be0SXin Li     armnn::BackendId setBackend;
689*3e777be0SXin Li     auto validateFunc = [&](const armnn::TensorInfo& cellStateOutInfo, bool& isSupported)
690*3e777be0SXin Li     {
691*3e777be0SXin Li         FORWARD_LAYER_SUPPORT_FUNC(__func__,
692*3e777be0SXin Li                                    IsQLstmSupported,
693*3e777be0SXin Li                                    data.m_Backends,
694*3e777be0SXin Li                                    isSupported,
695*3e777be0SXin Li                                    setBackend,
696*3e777be0SXin Li                                    inputInfo,
697*3e777be0SXin Li                                    outputStatePrevTimeStepInfo,
698*3e777be0SXin Li                                    cellStatePrevTimeStepInfo,
699*3e777be0SXin Li                                    constOutputStateOutInfo,
700*3e777be0SXin Li                                    cellStateOutInfo,
701*3e777be0SXin Li                                    constOutputInfo,
702*3e777be0SXin Li                                    desc,
703*3e777be0SXin Li                                    paramsInfo);
704*3e777be0SXin Li     };
705*3e777be0SXin Li 
706*3e777be0SXin Li     bool isDynamic = false;
707*3e777be0SXin Li     if (!IsDynamicTensor(constOutputStateOutInfo) &&
708*3e777be0SXin Li         !IsDynamicTensor(cellStateOutInfo)  &&
709*3e777be0SXin Li         !IsDynamicTensor(constOutputInfo))
710*3e777be0SXin Li     {
711*3e777be0SXin Li         validateFunc(outputInfo, isSupported);
712*3e777be0SXin Li     }
713*3e777be0SXin Li     else
714*3e777be0SXin Li     {
715*3e777be0SXin Li         isDynamic = true;
716*3e777be0SXin Li         isSupported = AreDynamicTensorsSupported();
717*3e777be0SXin Li     }
718*3e777be0SXin Li 
719*3e777be0SXin Li     if (!isSupported)
720*3e777be0SXin Li     {
721*3e777be0SXin Li         return false;
722*3e777be0SXin Li     }
723*3e777be0SXin Li 
724*3e777be0SXin Li     // Add the layer
725*3e777be0SXin Li     IConnectableLayer* layer = data.m_Network->AddQLstmLayer(desc, params, "QLstm");
726*3e777be0SXin Li     layer->SetBackendId(setBackend);
727*3e777be0SXin Li 
728*3e777be0SXin Li     input.Connect(layer->GetInputSlot(0));
729*3e777be0SXin Li     outputStatePrevTimeStep.Connect(layer->GetInputSlot(1));
730*3e777be0SXin Li     cellStatePrevTimeStep.Connect(layer->GetInputSlot(2));
731*3e777be0SXin Li 
732*3e777be0SXin Li     if (!isDynamic)
733*3e777be0SXin Li     {
734*3e777be0SXin Li         return ( SetupAndTrackLayerOutputSlot<HalPolicy>(
735*3e777be0SXin Li                        operation, 0, *layer, 0, model, data, &constOutputStateOutInfo) &&
736*3e777be0SXin Li                  SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 1, *layer, 1, model, data) &&
737*3e777be0SXin Li                  SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 2, *layer, 2, model, data, &constOutputInfo));
738*3e777be0SXin Li     }
739*3e777be0SXin Li     else
740*3e777be0SXin Li     {
741*3e777be0SXin Li         return ( SetupAndTrackLayerOutputSlot<HalPolicy>(
742*3e777be0SXin Li                        operation, 0, *layer, 0, model, data, &constOutputStateOutInfo) &&
743*3e777be0SXin Li                  SetupAndTrackLayerOutputSlot<HalPolicy>(
744*3e777be0SXin Li                        operation, 1, *layer, 1, model, data, nullptr, validateFunc,
745*3e777be0SXin Li                        ActivationFn::kActivationNone, true) &&
746*3e777be0SXin Li                  SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 2, *layer, 2, model, data, &constOutputInfo));
747*3e777be0SXin Li     }
748*3e777be0SXin Li }
749*3e777be0SXin Li 
750*3e777be0SXin Li template<typename HalPolicy,
751*3e777be0SXin Li          typename HalOperation = typename HalPolicy::Operation,
752*3e777be0SXin Li          typename HalModel     = typename HalPolicy::Model>
ConvertRank(const HalOperation & operation,const HalModel & model,ConversionData & data)753*3e777be0SXin Li bool ConvertRank(const HalOperation& operation, const HalModel& model, ConversionData& data)
754*3e777be0SXin Li {
755*3e777be0SXin Li     using HalOperand = typename HalPolicy::Operand;
756*3e777be0SXin Li 
757*3e777be0SXin Li     const HalOperand* inputOperand = GetInputOperand<HalPolicy>(operation, 0, model);
758*3e777be0SXin Li     const HalOperand* outputOperand = GetOutputOperand<HalPolicy>(operation, 0, model);
759*3e777be0SXin Li 
760*3e777be0SXin Li     if (inputOperand == nullptr || outputOperand == nullptr)
761*3e777be0SXin Li     {
762*3e777be0SXin Li         return Fail("%s: Operation has invalid inputs", __func__);
763*3e777be0SXin Li     }
764*3e777be0SXin Li 
765*3e777be0SXin Li     const Shape inputOperandShape = GetOperandShape(*inputOperand);
766*3e777be0SXin Li     const Shape outputOperandShape = GetOperandShape(*outputOperand);
767*3e777be0SXin Li 
768*3e777be0SXin Li     LayerInputHandle input = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);
769*3e777be0SXin Li     if (!input.IsValid())
770*3e777be0SXin Li     {
771*3e777be0SXin Li         return Fail("%s: Could not read input 0", __func__);
772*3e777be0SXin Li     }
773*3e777be0SXin Li 
774*3e777be0SXin Li     armnn::TensorInfo outInfo = GetTensorInfoForOperand(*outputOperand);
775*3e777be0SXin Li     if (IsDynamicTensor(outInfo))
776*3e777be0SXin Li     {
777*3e777be0SXin Li         return Fail("%s: Dynamic output tensors are not supported", __func__);
778*3e777be0SXin Li     }
779*3e777be0SXin Li 
780*3e777be0SXin Li     bool isSupported = false;
781*3e777be0SXin Li     armnn::BackendId setBackend;
782*3e777be0SXin Li     FORWARD_LAYER_SUPPORT_FUNC(__func__,
783*3e777be0SXin Li                                IsRankSupported,
784*3e777be0SXin Li                                data.m_Backends,
785*3e777be0SXin Li                                isSupported,
786*3e777be0SXin Li                                setBackend,
787*3e777be0SXin Li                                input.GetTensorInfo(),
788*3e777be0SXin Li                                outInfo);
789*3e777be0SXin Li     if (!isSupported)
790*3e777be0SXin Li     {
791*3e777be0SXin Li         return false;
792*3e777be0SXin Li     }
793*3e777be0SXin Li 
794*3e777be0SXin Li     armnn::IConnectableLayer* layer = data.m_Network->AddRankLayer();
795*3e777be0SXin Li     layer->SetBackendId(setBackend);
796*3e777be0SXin Li     if (!layer)
797*3e777be0SXin Li     {
798*3e777be0SXin Li         return Fail("%s: Could not add the RankLayer", __func__);
799*3e777be0SXin Li     }
800*3e777be0SXin Li     input.Connect(layer->GetInputSlot(0));
801*3e777be0SXin Li 
802*3e777be0SXin Li     return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, model, data, &outInfo);
803*3e777be0SXin Li }
804*3e777be0SXin Li 
805*3e777be0SXin Li } // armnn_driver namespace
806