//
// Copyright © 2020,2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "ConversionUtils_1_2.hpp"

using Half = half_float::half;

namespace armnn_driver
{

using namespace armnn;
using namespace android::nn;

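// Converts an NNAPI ELU activation operation into an ArmNN activation layer.
// Input 0 is the tensor to activate; input 1 is the scalar alpha, whose operand
// type (FLOAT16 or FLOAT32) is expected to match the input tensor's data type.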
template<typename HalPolicy,
         typename HalOperation = typename HalPolicy::Operation,
         typename HalModel = typename HalPolicy::Model>
bool ConvertElu(const HalOperation& operation, const HalModel& model, ConversionData& data)
{
    using HalOperandType = typename HalPolicy::OperandType;

    LayerInputHandle input0 = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);
    if (!input0.IsValid())
    {
        return Fail("%s: Operation has invalid inputs", __func__);
    }

    // Determine data type of input tensor
    HalOperandType inputType;
    if (!GetOperandType<HalPolicy>(operation, 0, model, inputType))
    {
        return Fail("%s: Operation has invalid inputs", __func__);
    }

    ActivationDescriptor desc;
    desc.m_Function = ActivationFunction::Elu;

    // Read alpha
    if (inputType == HalOperandType::TENSOR_FLOAT16)
    {
        Half alpha;

        if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::FLOAT16, alpha, model, data))
        {
            return Fail("%s: Operation has invalid inputs (FLOAT16)", __func__);
        }

        desc.m_A = static_cast<float>(alpha);
    }
    else if (inputType == HalOperandType::TENSOR_FLOAT32)
    {
        if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::FLOAT32, desc.m_A, model, data))
        {
            return Fail("%s: Operation has invalid inputs (FLOAT32)", __func__);
        }
    }
    else
    {
        return Fail("%s: Unsupported input tensor type: %d", __func__, inputType);
    }

    return ::ConvertToActivation<HalPolicy>(operation, __func__, desc, model, data);
}

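// Converts an NNAPI FILL operation into an ArmNN fill layer.
// Input 0 is the shape tensor; input 1 is the scalar fill value, read as
// FLOAT16, FLOAT32 or INT32 according to the output tensor's data type.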
template<typename HalPolicy,
         typename HalOperation = typename HalPolicy::Operation,
         typename HalModel = typename HalPolicy::Model>
bool ConvertFill(const HalOperation& operation, const HalModel& model, ConversionData& data)
{
    using HalOperand = typename HalPolicy::Operand;
    using HalOperandType = typename HalPolicy::OperandType;

    LayerInputHandle input = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);
    if (!input.IsValid())
    {
        return Fail("%s: Operation has invalid inputs", __func__);
    }

    const HalOperand* output = GetOutputOperand<HalPolicy>(operation, 0, model);
    if (!output)
    {
        return Fail("%s: Could not read output", __func__);
    }

    const TensorInfo& inputInfo  = input.GetTensorInfo();
    const TensorInfo& outputInfo = GetTensorInfoForOperand(*output);
    if (IsDynamicTensor(outputInfo))
    {
        return Fail("%s: Dynamic output tensors are not supported", __func__);
    }

    // Determine data type of output tensor
    HalOperandType outputType = output->type;
    FillDescriptor descriptor;
    // Read the scalar fill value
    if (outputType == HalOperandType::TENSOR_FLOAT16)
    {
        Half value;

        if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::FLOAT16, value, model, data))
        {
            return Fail("%s: Operation has invalid inputs %d", __func__, outputType);
        }

        descriptor.m_Value = static_cast<float>(value);
    }
    else if (outputType == HalOperandType::TENSOR_FLOAT32)
    {
        if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::FLOAT32, descriptor.m_Value, model, data))
        {
            return Fail("%s: Operation has invalid inputs %d", __func__, outputType);
        }
    }
    else if (outputType == HalOperandType::TENSOR_INT32)
    {
        int32_t value;

        if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::INT32, value, model, data))
        {
            return Fail("%s: Operation has invalid inputs %d", __func__, outputType);
        }

        descriptor.m_Value = static_cast<float>(value);
    }
    else
    {
        return Fail("%s: Unsupported input tensor type: %d", __func__, outputType);
    }

    bool isSupported = false;
    armnn::BackendId setBackend;
    FORWARD_LAYER_SUPPORT_FUNC(__func__,
                               IsFillSupported,
                               data.m_Backends,
                               isSupported,
                               setBackend,
                               inputInfo,
                               outputInfo,
                               descriptor);
    if (!isSupported)
    {
        return false;
    }

    IConnectableLayer* const layer = data.m_Network->AddFillLayer(descriptor);
    if (!layer)
    {
        return Fail("%s: Could not add the FillLayer", __func__);
    }
    layer->SetBackendId(setBackend);
    input.Connect(layer->GetInputSlot(0));

    return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, model, data);
}

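// Converts an NNAPI logical binary operation (LOGICAL_AND / LOGICAL_OR) into an
// ArmNN logical binary layer, broadcasting the two boolean inputs if necessary.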
template<typename HalPolicy,
         typename HalOperation = typename HalPolicy::Operation,
         typename HalModel = typename HalPolicy::Model>
bool ConvertLogicalBinary(const HalOperation& operation,
                          const HalModel& model,
                          ConversionData& data,
                          LogicalBinaryOperation logicalOperation)
{
    using HalOperand = typename HalPolicy::Operand;

    ALOGV("HalPolicy::ConvertLogicalBinary()");
    ALOGV("logicalOperation = %s", GetLogicalBinaryOperationAsCString(logicalOperation));

    LayerInputHandle input0 = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);
    LayerInputHandle input1 = ConvertToLayerInputHandle<HalPolicy>(operation, 1, model, data);

    if (!(input0.IsValid() && input1.IsValid()))
    {
        return Fail("%s: Operation has invalid inputs", __func__);
    }

    const HalOperand* output = GetOutputOperand<HalPolicy>(operation, 0, model);
    if (!output)
    {
        return Fail("%s: Could not read output 0", __func__);
    }

    const TensorInfo& inputInfo0 = input0.GetTensorInfo();
    const TensorInfo& inputInfo1 = input1.GetTensorInfo();
    const TensorInfo& outputInfo = GetTensorInfoForOperand(*output);

    LogicalBinaryDescriptor descriptor(logicalOperation);

    bool isSupported = false;
    armnn::BackendId setBackend;
    auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported)
    {
        FORWARD_LAYER_SUPPORT_FUNC(__func__,
                                   IsLogicalBinarySupported,
                                   data.m_Backends,
                                   isSupported,
                                   setBackend,
                                   inputInfo0,
                                   inputInfo1,
                                   outputInfo,
                                   descriptor);
    };

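    // If the output tensor is dynamic, the support check is deferred: validateFunc is
    // passed to SetupAndTrackLayerOutputSlot below and runs once the shape is inferred.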
    if (!IsDynamicTensor(outputInfo))
    {
        validateFunc(outputInfo, isSupported);
    }
    else
    {
        isSupported = AreDynamicTensorsSupported();
    }

    if (!isSupported)
    {
        return false;
    }

    IConnectableLayer* layer = data.m_Network->AddLogicalBinaryLayer(descriptor);
    if (!layer)
    {
        return Fail("%s: Could not add the LogicalBinaryLayer", __func__);
    }
    layer->SetBackendId(setBackend);

    bool isReshapeSupported = BroadcastTensor(input0, input1, layer, data);
    if (!isReshapeSupported)
    {
        return false;
    }

    return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, model, data, nullptr, validateFunc);
}

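// Converts an NNAPI QUANTIZED_LSTM (QLSTM) operation into an ArmNN QLstm layer.
// Gathers the mandatory and optional weight/bias tensors, the state inputs and the
// quantization scalars, infers the CIFG/peephole/projection/layer-norm configuration
// from which optional tensors are present, and validates the combination before
// adding the layer.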
template<typename HalPolicy,
         typename HalOperation = typename HalPolicy::Operation,
         typename HalModel = typename HalPolicy::Model>
bool ConvertQuantizedLstm(const HalOperation& operation, const HalModel& model, ConversionData& data)
{
    using HalOperand = typename HalPolicy::Operand;
    using HalOperandType = typename HalPolicy::OperandType;

    ALOGV("HalPolicy::ConvertQuantizedLstm()");

    // Inputs:
    // 0: The input: A 2-D tensor of type ANEURALNETWORKS_TENSOR_QUANT8_ASYMM and shape [numBatches, inputSize]
    //    specifying the input to the LSTM cell. Tensor is quantized with a fixed quantization range of -1, 127/128.
    LayerInputHandle input = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);
    if (!input.IsValid())
    {
        return Fail("%s: Could not read input 0: input", __func__);
    }

    // 18: The output state: A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_ASYMM, of shape [batch_size, output_size].
    LayerInputHandle outputStatePrevTimeStep = ConvertToLayerInputHandle<HalPolicy>(operation, 18, model, data);
    if (!outputStatePrevTimeStep.IsValid())
    {
        return Fail("%s: Could not read input 18: outputStatePrevTimeStep", __func__);
    }

    // 19: The cell state: A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT16_SYMM, of shape [batch_size, num_units].
    LayerInputHandle cellStatePrevTimeStep = ConvertToLayerInputHandle<HalPolicy>(operation, 19, model, data);
    if (!cellStatePrevTimeStep.IsValid())
    {
        return Fail("%s: Could not read input 19: cellStatePrevTimeStep", __func__);
    }

    // Get the mandatory input tensors:

    // 02: The input-to-forget weights: A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_SYMM, of shape
    //     [num_units, input_size].
    const ConstTensorPin inputToForgetWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 2, model, data);

    // 03: The input-to-cell weights: A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_SYMM, of shape
    //     [num_units, input_size].
    const ConstTensorPin inputToCellWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 3, model, data);

    // 04: The input-to-output weights: A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_SYMM, of shape
    //     [num_units, input_size].
    const ConstTensorPin inputToOutputWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 4, model, data);

    // 06: The recurrent-to-forget weights: A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_SYMM, of shape
    //     [num_units, output_size].
    const ConstTensorPin recurrentToForgetWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 6, model, data);

    // 07: The recurrent-to-cell weights: A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_SYMM, of shape
    //     [num_units, output_size].
    const ConstTensorPin recurrentToCellWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 7, model, data);

    // 08: The recurrent-to-output weights: A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_SYMM, of shape
    //     [num_units, output_size].
    const ConstTensorPin recurrentToOutputWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 8, model, data);

    // 13: The forget gate bias: A 1-D tensor of ANEURALNETWORKS_TENSOR_INT32, of shape [num_units].
    const ConstTensorPin forgetGateBiasPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 13, model, data);

    // 14: The cell bias: A 1-D tensor of ANEURALNETWORKS_TENSOR_INT32, of shape [num_units].
    const ConstTensorPin cellBiasPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 14, model, data);

    // 15: The output gate bias: A 1-D tensor of ANEURALNETWORKS_TENSOR_INT32, of shape [num_units].
    const ConstTensorPin outputGateBiasPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 15, model, data);

    if (!inputToForgetWeightsPin.IsValid() ||
        !inputToCellWeightsPin.IsValid() ||
        !inputToOutputWeightsPin.IsValid() ||
        !recurrentToForgetWeightsPin.IsValid() ||
        !recurrentToCellWeightsPin.IsValid() ||
        !recurrentToOutputWeightsPin.IsValid() ||
        !forgetGateBiasPin.IsValid() ||
        !cellBiasPin.IsValid() ||
        !outputGateBiasPin.IsValid())
    {
        return Fail("%s: Operation has invalid tensor inputs", __func__);
    }

    // Get the optional input tensors:
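    // The trailing 'true' argument in each call below flags the operand as optional,
    // so a missing operand yields a pin that reports IsOptional() instead of a hard
    // failure (see the combined IsValid()/IsOptional() checks after this group).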

    // 01: The input-to-input weights: Optional. A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_SYMM, of shape
    //     [num_units, input_size], where “num_units” corresponds to the number of cell units.
    const ConstTensorPin inputToInputWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         1,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    // 05: The recurrent-to-input weights: Optional. A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_SYMM, of shape
    //     [num_units, output_size], where “output_size” corresponds to either the number of cell units (i.e.,
    //     “num_units”), or the second dimension of the “projection_weights”, if defined.
    const ConstTensorPin recurrentToInputWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         5,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    // 09: The cell-to-input weights: Optional. A 1-D tensor of ANEURALNETWORKS_TENSOR_QUANT16_SYMM, of shape
    //     [num_units].
    const ConstTensorPin cellToInputWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         9,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    // 10: The cell-to-forget weights: Optional. A 1-D tensor of ANEURALNETWORKS_TENSOR_QUANT16_SYMM, of shape
    //     [num_units].
    const ConstTensorPin cellToForgetWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         10,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    // 11: The cell-to-output weights: Optional. A 1-D tensor of ANEURALNETWORKS_TENSOR_QUANT16_SYMM, of shape
    //     [num_units].
    const ConstTensorPin cellToOutputWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         11,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    // 12: The input gate bias: Optional. A 1-D tensor of ANEURALNETWORKS_TENSOR_INT32, of shape [num_units].
    const ConstTensorPin inputGateBiasPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         12,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    // 16: The projection weights: Optional. A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_SYMM, of shape
    //     [output_size, num_units].
    const ConstTensorPin projectionWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         16,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    // 17: The projection bias: Optional. A 1-D tensor of ANEURALNETWORKS_TENSOR_INT32, of shape [output_size].
    const ConstTensorPin projectionBiasPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         17,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    if ((!inputToInputWeightsPin.IsValid() && !inputToInputWeightsPin.IsOptional())
        || (!recurrentToInputWeightsPin.IsValid() && !recurrentToInputWeightsPin.IsOptional())
        || (!cellToInputWeightsPin.IsValid() && !cellToInputWeightsPin.IsOptional())
        || (!cellToForgetWeightsPin.IsValid() && !cellToForgetWeightsPin.IsOptional())
        || (!cellToOutputWeightsPin.IsValid() && !cellToOutputWeightsPin.IsOptional())
        || (!inputGateBiasPin.IsValid() && !inputGateBiasPin.IsOptional())
        || (!projectionWeightsPin.IsValid() && !projectionWeightsPin.IsOptional())
        || (!projectionBiasPin.IsValid() && !projectionBiasPin.IsOptional()))
    {
        return Fail("%s: Operation has invalid tensor inputs", __func__);
    }

    // Get the optional normalization tensors:

    // 20: The input layer normalization weights. A 1-D tensor of shape [num_units] ANEURALNETWORKS_TENSOR_QUANT16_SYMM.
    //     Used to rescale normalized inputs to activation at input gate.
    const ConstTensorPin inputLayerNormWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         20,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    // 21: The forget layer normalization weights. A 1-D tensor of shape [num_units] ANEURALNETWORKS_TENSOR_QUANT16_SYMM.
    //     Used to rescale normalized inputs to activation at forget gate.
    const ConstTensorPin forgetLayerNormWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         21,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    // 22: The cell layer normalization weights. A 1-D tensor of shape [num_units] ANEURALNETWORKS_TENSOR_QUANT16_SYMM.
    //     Used to rescale normalized inputs to activation at cell gate.
    const ConstTensorPin cellLayerNormWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         22,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    // 23: The output layer normalization weights. A 1-D tensor of shape [num_units] ANEURALNETWORKS_TENSOR_QUANT16_SYMM.
    //     Used to rescale normalized inputs to activation at output gate.
    const ConstTensorPin outputLayerNormWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         23,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    if ((!inputLayerNormWeightsPin.IsValid() && !inputLayerNormWeightsPin.IsOptional())
        || (!forgetLayerNormWeightsPin.IsValid() && !forgetLayerNormWeightsPin.IsOptional())
        || (!cellLayerNormWeightsPin.IsValid() && !cellLayerNormWeightsPin.IsOptional())
        || (!outputLayerNormWeightsPin.IsValid() && !outputLayerNormWeightsPin.IsOptional()))
    {
        return Fail("%s: Operation has invalid tensor inputs", __func__);
    }

    // Get the optional input scalars:
    // 24: The cell clip: If provided the cell state is clipped by this value prior to the cell output activation.
    // 25: The projection clip: If provided and projection is enabled, this is used for clipping the projected values.

    // Get the mandatory input scalars:
    // 26: The scale of the intermediate result of matmul, i.e. input to layer normalization, at input gate.
    // 27: The scale of the intermediate result of matmul, i.e. input to layer normalization, at forget gate.
    // 28: The scale of the intermediate result of matmul, i.e. input to layer normalization, at cell gate.
    // 29: The scale of the intermediate result of matmul, i.e. input to layer normalization, at output gate.
    // 30: The zero point of the hidden state, i.e. input to projection.
    // 31: The scale of the hidden state, i.e. input to projection.
    float cellClip, projClip, matMulInputGate, matMulForgetGate, matMulCellGate, matMulOutputGate, projInputScale;
    int projInputZeroPoint;

    if (!GetInputScalar<HalPolicy>(operation, 24, HalOperandType::FLOAT32, cellClip, model, data, true) ||
        !GetInputScalar<HalPolicy>(operation, 25, HalOperandType::FLOAT32, projClip, model, data, true) ||
        !GetInputScalar<HalPolicy>(operation, 26, HalOperandType::FLOAT32, matMulInputGate, model, data) ||
        !GetInputScalar<HalPolicy>(operation, 27, HalOperandType::FLOAT32, matMulForgetGate, model, data) ||
        !GetInputScalar<HalPolicy>(operation, 28, HalOperandType::FLOAT32, matMulCellGate, model, data) ||
        !GetInputScalar<HalPolicy>(operation, 29, HalOperandType::FLOAT32, matMulOutputGate, model, data) ||
        !GetInputScalar<HalPolicy>(operation, 30, HalOperandType::INT32, projInputZeroPoint, model, data) ||
        !GetInputScalar<HalPolicy>(operation, 31, HalOperandType::FLOAT32, projInputScale, model, data))
    {
        return Fail("%s: Operation has invalid scalar inputs", __func__);
    }

    // Outputs:
    // 0: The output state (out): A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_ASYMM_SIGNED, of shape [batch_size,
    //    output_size].
    const HalOperand* outputStateOut = GetOutputOperand<HalPolicy>(operation, 0, model);
    if (!outputStateOut)
    {
        return Fail("%s: Could not read output 0: outputStateOut", __func__);
    }

    // 1: The cell state (out): A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT16_SYMM, of shape [batch_size, num_units].
    const HalOperand* cellStateOut = GetOutputOperand<HalPolicy>(operation, 1, model);
    if (!cellStateOut)
    {
        return Fail("%s: Could not read output 1: cellStateOut", __func__);
    }

    // 2: The output: A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_ASYMM_SIGNED, of shape [batch_size, output_size].
    //    This is effectively the same as the current “output state (out)” value.
    const HalOperand* output = GetOutputOperand<HalPolicy>(operation, 2, model);
    if (!output)
    {
        return Fail("%s: Could not read output 2: output", __func__);
    }

    // Set the params structure for the AddQLstmLayer call
    LstmInputParams params;
    params.m_InputToInputWeights = inputToInputWeightsPin.GetConstTensorPtr();
    params.m_InputToForgetWeights = inputToForgetWeightsPin.GetConstTensorPtr();
    params.m_InputToCellWeights = inputToCellWeightsPin.GetConstTensorPtr();
    params.m_InputToOutputWeights = inputToOutputWeightsPin.GetConstTensorPtr();
    params.m_RecurrentToInputWeights = recurrentToInputWeightsPin.GetConstTensorPtr();
    params.m_RecurrentToForgetWeights = recurrentToForgetWeightsPin.GetConstTensorPtr();
    params.m_RecurrentToCellWeights = recurrentToCellWeightsPin.GetConstTensorPtr();
    params.m_RecurrentToOutputWeights = recurrentToOutputWeightsPin.GetConstTensorPtr();
    params.m_CellToInputWeights = cellToInputWeightsPin.GetConstTensorPtr();
    params.m_CellToForgetWeights = cellToForgetWeightsPin.GetConstTensorPtr();
    params.m_CellToOutputWeights = cellToOutputWeightsPin.GetConstTensorPtr();
    params.m_InputGateBias = inputGateBiasPin.GetConstTensorPtr();
    params.m_ForgetGateBias = forgetGateBiasPin.GetConstTensorPtr();
    params.m_CellBias = cellBiasPin.GetConstTensorPtr();
    params.m_OutputGateBias = outputGateBiasPin.GetConstTensorPtr();
    params.m_ProjectionWeights = projectionWeightsPin.GetConstTensorPtr();
    params.m_ProjectionBias = projectionBiasPin.GetConstTensorPtr();
    params.m_InputLayerNormWeights = inputLayerNormWeightsPin.GetConstTensorPtr();
    params.m_ForgetLayerNormWeights = forgetLayerNormWeightsPin.GetConstTensorPtr();
    params.m_CellLayerNormWeights = cellLayerNormWeightsPin.GetConstTensorPtr();
    params.m_OutputLayerNormWeights = outputLayerNormWeightsPin.GetConstTensorPtr();

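    // The optional-feature flags below are inferred from which optional tensors were
    // supplied: CIFG (coupled input-forget gate) when the input gate tensors are absent,
    // peephole when cell-to-gate weights are present, projection when projection weights
    // are present, and layer normalization when any layer-norm weights are present.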
    // Set the layer descriptor
    QLstmDescriptor desc;
    desc.m_CellClip = cellClip;
    desc.m_ProjectionClip = projClip;
    desc.m_CifgEnabled = (params.m_InputToInputWeights == nullptr ||
                          params.m_RecurrentToInputWeights == nullptr ||
                          params.m_InputGateBias == nullptr);
    desc.m_PeepholeEnabled = (params.m_CellToForgetWeights != nullptr ||
                              params.m_CellToOutputWeights != nullptr);
    desc.m_ProjectionEnabled = (params.m_ProjectionWeights != nullptr);
    desc.m_LayerNormEnabled = (params.m_InputLayerNormWeights != nullptr ||
                               params.m_ForgetLayerNormWeights != nullptr ||
                               params.m_CellLayerNormWeights != nullptr ||
                               params.m_OutputLayerNormWeights != nullptr);
    desc.m_InputIntermediateScale = matMulInputGate;
    desc.m_ForgetIntermediateScale = matMulForgetGate;
    desc.m_CellIntermediateScale = matMulCellGate;
    desc.m_OutputIntermediateScale = matMulOutputGate;
    desc.m_HiddenStateScale = projInputScale;
    desc.m_HiddenStateZeroPoint = projInputZeroPoint;

    // Validate the optional input groups
    if (desc.m_CifgEnabled &&
        (params.m_InputToInputWeights != nullptr ||
         params.m_RecurrentToInputWeights != nullptr ||
         params.m_InputGateBias != nullptr))
    {
        return Fail("%s: All, or none, of input-to-input weights, recurrent-to-input weights,"
                    " and input gate bias must be provided", __func__);
    }

    if (!desc.m_ProjectionEnabled && params.m_ProjectionBias != nullptr)
    {
        return Fail("%s: projection bias should not be provided without projection weights", __func__);
    }

    if (desc.m_PeepholeEnabled &&
        (params.m_CellToForgetWeights == nullptr ||
         params.m_CellToOutputWeights == nullptr ||
         (!desc.m_CifgEnabled && params.m_CellToInputWeights == nullptr)))
    {
        return Fail("%s: All, or none, of cell-to-forget weights and cell-to-output weights must be provided"
                    " and, if CIFG is not enabled, cell-to-input weights must also be provided", __func__);
    }

    if (desc.m_LayerNormEnabled &&
        (params.m_ForgetLayerNormWeights == nullptr ||
         params.m_CellLayerNormWeights == nullptr ||
         params.m_OutputLayerNormWeights == nullptr ||
         (!desc.m_CifgEnabled && params.m_InputLayerNormWeights == nullptr)))
    {
        return Fail("%s: All, or none, of forget-norm weights, cell-norm weights and output-norm weights must be"
                    " provided and, if CIFG is not enabled, input-norm weights must also be provided", __func__);
    }

    // Basic parameters
    LstmInputParamsInfo paramsInfo;
    paramsInfo.m_InputToForgetWeights = &(params.m_InputToForgetWeights->GetInfo());
    paramsInfo.m_InputToCellWeights = &(params.m_InputToCellWeights->GetInfo());
    paramsInfo.m_InputToOutputWeights = &(params.m_InputToOutputWeights->GetInfo());
    paramsInfo.m_RecurrentToForgetWeights = &(params.m_RecurrentToForgetWeights->GetInfo());
    paramsInfo.m_RecurrentToCellWeights = &(params.m_RecurrentToCellWeights->GetInfo());
    paramsInfo.m_RecurrentToOutputWeights = &(params.m_RecurrentToOutputWeights->GetInfo());
    paramsInfo.m_ForgetGateBias = &(params.m_ForgetGateBias->GetInfo());
    paramsInfo.m_CellBias = &(params.m_CellBias->GetInfo());
    paramsInfo.m_OutputGateBias = &(params.m_OutputGateBias->GetInfo());

    // Inputs
    const TensorInfo& inputInfo = input.GetTensorInfo();
    const TensorInfo& outputStatePrevTimeStepInfo = outputStatePrevTimeStep.GetTensorInfo();
    const TensorInfo& cellStatePrevTimeStepInfo = cellStatePrevTimeStep.GetTensorInfo();

    // Outputs
    TensorInfo outputStateOutInfo = GetTensorInfoForOperand(*outputStateOut);
    TensorInfo outputInfo = GetTensorInfoForOperand(*output);
    const TensorInfo& cellStateOutInfo = GetTensorInfoForOperand(*cellStateOut);

    // Optional parameters
    if (!desc.m_CifgEnabled)
    {
        paramsInfo.m_InputToInputWeights = &(params.m_InputToInputWeights->GetInfo());
        paramsInfo.m_RecurrentToInputWeights = &(params.m_RecurrentToInputWeights->GetInfo());
        if (desc.m_PeepholeEnabled)
        {
            paramsInfo.m_CellToInputWeights = &(params.m_CellToInputWeights->GetInfo());
        }
        paramsInfo.m_InputGateBias = &(params.m_InputGateBias->GetInfo());
    }

    if (desc.m_ProjectionEnabled)
    {
        paramsInfo.m_ProjectionWeights = &(params.m_ProjectionWeights->GetInfo());
        if (params.m_ProjectionBias != nullptr)
        {
            paramsInfo.m_ProjectionBias = &(params.m_ProjectionBias->GetInfo());
        }
    }
    else
    {
        // If projection is disabled, override the non-const outputs to change the quant info with hidden params,
        // then create a new const TensorInfo based on this
        outputStateOutInfo.SetQuantizationScale(projInputScale);
        outputStateOutInfo.SetQuantizationOffset(projInputZeroPoint);
        outputInfo.SetQuantizationScale(projInputScale);
        outputInfo.SetQuantizationOffset(projInputZeroPoint);
    }

    const TensorInfo constOutputStateOutInfo(outputStateOutInfo);
    const TensorInfo constOutputInfo(outputInfo);

    if (desc.m_PeepholeEnabled)
    {
        paramsInfo.m_CellToForgetWeights = &(params.m_CellToForgetWeights->GetInfo());
        paramsInfo.m_CellToOutputWeights = &(params.m_CellToOutputWeights->GetInfo());
    }

    if (desc.m_LayerNormEnabled)
    {
        if (!desc.m_CifgEnabled)
        {
            paramsInfo.m_InputLayerNormWeights = &(params.m_InputLayerNormWeights->GetInfo());
        }
        paramsInfo.m_ForgetLayerNormWeights = &(params.m_ForgetLayerNormWeights->GetInfo());
        paramsInfo.m_CellLayerNormWeights = &(params.m_CellLayerNormWeights->GetInfo());
        paramsInfo.m_OutputLayerNormWeights = &(params.m_OutputLayerNormWeights->GetInfo());
    }

    // Check if the layer is supported
    bool isSupported = false;
    armnn::BackendId setBackend;
    auto validateFunc = [&](const armnn::TensorInfo& cellStateOutInfo, bool& isSupported)
    {
        FORWARD_LAYER_SUPPORT_FUNC(__func__,
                                   IsQLstmSupported,
                                   data.m_Backends,
                                   isSupported,
                                   setBackend,
                                   inputInfo,
                                   outputStatePrevTimeStepInfo,
                                   cellStatePrevTimeStepInfo,
                                   constOutputStateOutInfo,
                                   cellStateOutInfo,
                                   constOutputInfo,
                                   desc,
                                   paramsInfo);
    };

    bool isDynamic = false;
    if (!IsDynamicTensor(constOutputStateOutInfo) &&
        !IsDynamicTensor(cellStateOutInfo) &&
        !IsDynamicTensor(constOutputInfo))
    {
        validateFunc(cellStateOutInfo, isSupported);
    }
    else
    {
        isDynamic = true;
        isSupported = AreDynamicTensorsSupported();
    }

    if (!isSupported)
    {
        return false;
    }

    // Add the layer
    IConnectableLayer* layer = data.m_Network->AddQLstmLayer(desc, params, "QLstm");
    if (!layer)
    {
        return Fail("%s: Could not add the QLstmLayer", __func__);
    }
    layer->SetBackendId(setBackend);

    input.Connect(layer->GetInputSlot(0));
    outputStatePrevTimeStep.Connect(layer->GetInputSlot(1));
    cellStatePrevTimeStep.Connect(layer->GetInputSlot(2));

    if (!isDynamic)
    {
        return ( SetupAndTrackLayerOutputSlot<HalPolicy>(
                     operation, 0, *layer, 0, model, data, &constOutputStateOutInfo) &&
                 SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 1, *layer, 1, model, data) &&
                 SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 2, *layer, 2, model, data, &constOutputInfo));
    }
    else
    {
        return ( SetupAndTrackLayerOutputSlot<HalPolicy>(
                     operation, 0, *layer, 0, model, data, &constOutputStateOutInfo) &&
                 SetupAndTrackLayerOutputSlot<HalPolicy>(
                     operation, 1, *layer, 1, model, data, nullptr, validateFunc,
                     ActivationFn::kActivationNone, true) &&
                 SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 2, *layer, 2, model, data, &constOutputInfo));
    }
}

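// Converts an NNAPI RANK operation into an ArmNN rank layer, which outputs the
// number of dimensions of the input tensor as a scalar.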
template<typename HalPolicy,
         typename HalOperation = typename HalPolicy::Operation,
         typename HalModel = typename HalPolicy::Model>
bool ConvertRank(const HalOperation& operation, const HalModel& model, ConversionData& data)
{
    using HalOperand = typename HalPolicy::Operand;

    const HalOperand* inputOperand = GetInputOperand<HalPolicy>(operation, 0, model);
    const HalOperand* outputOperand = GetOutputOperand<HalPolicy>(operation, 0, model);

    if (inputOperand == nullptr || outputOperand == nullptr)
    {
        return Fail("%s: Operation has invalid inputs", __func__);
    }

    const Shape inputOperandShape = GetOperandShape(*inputOperand);
    const Shape outputOperandShape = GetOperandShape(*outputOperand);

    LayerInputHandle input = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);
    if (!input.IsValid())
    {
        return Fail("%s: Could not read input 0", __func__);
    }

    armnn::TensorInfo outInfo = GetTensorInfoForOperand(*outputOperand);
    if (IsDynamicTensor(outInfo))
    {
        return Fail("%s: Dynamic output tensors are not supported", __func__);
    }

    bool isSupported = false;
    armnn::BackendId setBackend;
    FORWARD_LAYER_SUPPORT_FUNC(__func__,
                               IsRankSupported,
                               data.m_Backends,
                               isSupported,
                               setBackend,
                               input.GetTensorInfo(),
                               outInfo);
    if (!isSupported)
    {
        return false;
    }

    armnn::IConnectableLayer* layer = data.m_Network->AddRankLayer();
    if (!layer)
    {
        return Fail("%s: Could not add the RankLayer", __func__);
    }
    layer->SetBackendId(setBackend);
    input.Connect(layer->GetInputSlot(0));

    return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, model, data, &outInfo);
}

} // armnn_driver namespace