//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "RefQLstmWorkload.hpp"
#include "Activation.hpp"
#include "Encoders.hpp"
#include "Decoders.hpp"
#include "LstmUtils.hpp"
#include "RefWorkloadUtils.hpp"

namespace armnn
{

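// The constructor copies the constant weight and bias tensors from the descriptor into
// ScopedTensorHandles owned by this workload (via AssignScopedTensorHandle).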
RefQLstmWorkload::RefQLstmWorkload(const QLstmQueueDescriptor& descriptor, const WorkloadInfo& info)
        : RefBaseWorkload<QLstmQueueDescriptor>(descriptor, info)
        , m_InputToInputWeightsTensor     (AssignScopedTensorHandle(descriptor.m_InputToInputWeights))
        , m_InputToForgetWeightsTensor    (AssignScopedTensorHandle(descriptor.m_InputToForgetWeights))
        , m_InputToCellWeightsTensor      (AssignScopedTensorHandle(descriptor.m_InputToCellWeights))
        , m_InputToOutputWeightsTensor    (AssignScopedTensorHandle(descriptor.m_InputToOutputWeights))

        , m_RecurrentToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToInputWeights))
        , m_RecurrentToForgetWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToForgetWeights))
        , m_RecurrentToCellWeightsTensor  (AssignScopedTensorHandle(descriptor.m_RecurrentToCellWeights))
        , m_RecurrentToOutputWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToOutputWeights))

        , m_CellToInputWeightsTensor      (AssignScopedTensorHandle(descriptor.m_CellToInputWeights))
        , m_CellToForgetWeightsTensor     (AssignScopedTensorHandle(descriptor.m_CellToForgetWeights))
        , m_CellToOutputWeightsTensor     (AssignScopedTensorHandle(descriptor.m_CellToOutputWeights))

        , m_InputGateBiasTensor           (AssignScopedTensorHandle(descriptor.m_InputGateBias))
        , m_ForgetGateBiasTensor          (AssignScopedTensorHandle(descriptor.m_ForgetGateBias))
        , m_CellBiasTensor                (AssignScopedTensorHandle(descriptor.m_CellBias))
        , m_OutputGateBiasTensor          (AssignScopedTensorHandle(descriptor.m_OutputGateBias))

        , m_ProjectionWeightsTensor       (AssignScopedTensorHandle(descriptor.m_ProjectionWeights))
        , m_ProjectionBiasTensor          (AssignScopedTensorHandle(descriptor.m_ProjectionBias))

        , m_InputLayerNormWeightsTensor   (AssignScopedTensorHandle(descriptor.m_InputLayerNormWeights))
        , m_ForgetLayerNormWeightsTensor  (AssignScopedTensorHandle(descriptor.m_ForgetLayerNormWeights))
        , m_CellLayerNormWeightsTensor    (AssignScopedTensorHandle(descriptor.m_CellLayerNormWeights))
        , m_OutputLayerNormWeightsTensor  (AssignScopedTensorHandle(descriptor.m_OutputLayerNormWeights))
{}

void RefQLstmWorkload::Execute() const
{
    Execute(m_Data.m_Inputs, m_Data.m_Outputs);
}

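// Async execution entry point: the opaque ExecutionData payload carries a
// WorkingMemDescriptor holding the per-execution input and output tensor handles.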
void RefQLstmWorkload::ExecuteAsync(ExecutionData& executionData)
{
    WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data);
    Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs);
}

void RefQLstmWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
{
    // This is a port of the QLSTM::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs)
    // method from the Android code base.
    // Note: this implementation wraps the arithmetic functions of the LSTM cell in Quantize/Dequantize ops, so all
    // computation is done in the floating point domain. Arithmetic functions are found in LstmUtils.cpp.
    // Refer to: android/frameworks/ml/nn/common/operations/QLSTM.cpp
    const DataType& internalType = armnn::DataType::QSymmS16;

    const TensorInfo& inputInfo = GetTensorInfo(inputs[0]);
    const TensorInfo& outputStateInInfo = GetTensorInfo(inputs[1]);
    const TensorInfo& cellStateInInfo = GetTensorInfo(inputs[2]);

    const TensorInfo& outputStateOutInfo = GetTensorInfo(outputs[0]);
    const TensorInfo& cellStateOutInfo = GetTensorInfo(outputs[1]);
    const TensorInfo& outputInfo = GetTensorInfo(outputs[2]);

    const TensorShape& inputShape = inputInfo.GetShape();
    const TensorShape& outputStateInShape = outputStateInInfo.GetShape();
    const TensorShape& cellStateInShape = cellStateInInfo.GetShape();

    // Infer numBatches, inputSize, outputSize and numUnits
    const uint32_t numBatches = inputShape[0];
    const uint32_t inputSize  = inputShape[1];
    const uint32_t outputSize = outputStateInShape[1];
    const uint32_t numUnits   = cellStateInShape[1];

    // Optional param settings
    const bool cifgEnabled      = m_Data.m_Parameters.m_CifgEnabled;
    const bool peepholeEnabled  = m_Data.m_Parameters.m_PeepholeEnabled;
    const bool projectionEnabled = m_Data.m_Parameters.m_ProjectionEnabled;
    const bool layerNormEnabled = m_Data.m_Parameters.m_LayerNormEnabled;

    // Input decoders
    std::unique_ptr<Decoder<float>> inputDecoder =
            MakeDecoder<float>(inputInfo, inputs[0]->Map());
    std::unique_ptr<Decoder<float>> outputStateInDecoder =
            MakeDecoder<float>(outputStateInInfo, inputs[1]->Map());
    std::unique_ptr<Decoder<float>> cellStateInDecoder =
            MakeDecoder<float>(cellStateInInfo, inputs[2]->Map());

    // Output decoders
    std::unique_ptr<Decoder<float>> outputStateOutDecoder =
            MakeDecoder<float>(outputStateOutInfo, outputs[0]->Map());
    std::unique_ptr<Decoder<float>> cellStateOutDecoder =
            MakeDecoder<float>(cellStateOutInfo, outputs[1]->Map());
    std::unique_ptr<Decoder<float>> outputDecoder =
            MakeDecoder<float>(outputInfo, outputs[2]->Map());

    // Output encoders
    std::unique_ptr<Encoder<float>> outputStateOutEncoder =
            MakeEncoder<float>(outputStateOutInfo, outputs[0]->Map());
    std::unique_ptr<Encoder<float>> cellStateOutEncoder =
            MakeEncoder<float>(cellStateOutInfo, outputs[1]->Map());
    std::unique_ptr<Encoder<float>> outputEncoder =
            MakeEncoder<float>(outputInfo, outputs[2]->Map());

    // Weights decoders
    std::unique_ptr<Decoder<float>> inputToForgetWeightsDecoder = MakeDecoder<float>(
            m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> inputToCellWeightsDecoder = MakeDecoder<float>(
            m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> inputToOutputWeightsDecoder = MakeDecoder<float>(
            m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<void>());

    std::unique_ptr<Decoder<float>> recurrentToForgetWeightsDecoder = MakeDecoder<float>(
            m_RecurrentToForgetWeightsTensor->GetTensorInfo(),
            m_RecurrentToForgetWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> recurrentToCellWeightsDecoder = MakeDecoder<float>(
            m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> recurrentToOutputWeightsDecoder = MakeDecoder<float>(
            m_RecurrentToOutputWeightsTensor->GetTensorInfo(),
            m_RecurrentToOutputWeightsTensor->GetConstTensor<void>());

    // Optional CIFG params
    std::unique_ptr<Decoder<float>> inputToInputWeightsDecoder;
    std::unique_ptr<Decoder<float>> recurrentToInputWeightsDecoder;
    std::unique_ptr<Decoder<float>> inputGateBiasDecoder;

    // Optional Peephole params
    std::unique_ptr<Decoder<float>> cellToInputWeightsDecoder;
    std::unique_ptr<Decoder<float>> cellToForgetWeightsDecoder;
    std::unique_ptr<Decoder<float>> cellToOutputWeightsDecoder;

    // Optional Projection params
    std::unique_ptr<Decoder<float>> projectionWeightsDecoder;
    std::unique_ptr<Decoder<float>> projectionBiasDecoder;

    // Optional Layer Norm params
    std::unique_ptr<Decoder<float>> inputLayerNormWeightsDecoder;
    std::unique_ptr<Decoder<float>> forgetLayerNormWeightsDecoder;
    std::unique_ptr<Decoder<float>> cellLayerNormWeightsDecoder;
    std::unique_ptr<Decoder<float>> outputLayerNormWeightsDecoder;

    // Biases are only used when Layer Norm is enabled. Scale is defined as (XLayerNormWeights Scale / 1024)
    std::unique_ptr<Decoder<float>> forgetGateBiasDecoder;
    std::unique_ptr<Decoder<float>> cellGateBiasDecoder;
    std::unique_ptr<Decoder<float>> outputGateBiasDecoder;

    // Int16 vectors for internal state data (to be decoded/encoded)
    const uint32_t stateTensorSize = numBatches * numUnits;
    std::vector<int16_t> inputGateData(stateTensorSize);
    std::vector<int16_t> cellGateData(stateTensorSize);
    std::vector<int16_t> forgetGateData(stateTensorSize);
    std::vector<int16_t> outputGateData(stateTensorSize);
    std::vector<int32_t> hiddenStateData(stateTensorSize);
    std::vector<int16_t> outputInt16Data(numBatches * outputSize);

    armnn::TensorInfo inputGateInfo(
            {numBatches, numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_InputIntermediateScale, 0);
    armnn::TensorInfo cellGateInfo(
            {numBatches, numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_CellIntermediateScale, 0);
    armnn::TensorInfo forgetGateInfo(
            {numBatches, numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_ForgetIntermediateScale, 0);
    armnn::TensorInfo outputGateInfo(
            {numBatches, numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_OutputIntermediateScale, 0);
    armnn::TensorInfo hiddenStateInfo({numBatches, numUnits},
                                      armnn::DataType::QAsymmS8,
                                      m_Data.m_Parameters.m_HiddenStateScale,
                                      m_Data.m_Parameters.m_HiddenStateZeroPoint);
    armnn::TensorInfo outputInt16Info({numBatches, outputSize},
                                      armnn::DataType::QSymmS16,
                                      outputInfo.GetQuantizationScale(),
                                      outputInfo.GetQuantizationOffset());

    // Decoders/Encoders for internal states
    std::unique_ptr<Decoder<float>> inputGateDecoder =
            MakeDecoder<float>(inputGateInfo, inputGateData.data());
    std::unique_ptr<Decoder<float>> cellGateDecoder =
            MakeDecoder<float>(cellGateInfo, cellGateData.data());
    std::unique_ptr<Decoder<float>> forgetGateDecoder =
            MakeDecoder<float>(forgetGateInfo, forgetGateData.data());
    std::unique_ptr<Decoder<float>> outputGateDecoder =
            MakeDecoder<float>(outputGateInfo, outputGateData.data());
    std::unique_ptr<Decoder<float>> hiddenStateDecoder =
            MakeDecoder<float>(hiddenStateInfo, hiddenStateData.data());

    std::unique_ptr<Encoder<float>> inputGateEncoder =
            MakeEncoder<float>(inputGateInfo, inputGateData.data());
    std::unique_ptr<Encoder<float>> cellGateEncoder =
            MakeEncoder<float>(cellGateInfo, cellGateData.data());
    std::unique_ptr<Encoder<float>> forgetGateEncoder =
            MakeEncoder<float>(forgetGateInfo, forgetGateData.data());
    std::unique_ptr<Encoder<float>> outputGateEncoder =
            MakeEncoder<float>(outputGateInfo, outputGateData.data());
    std::unique_ptr<Encoder<float>> hiddenStateEncoder =
            MakeEncoder<float>(hiddenStateInfo, hiddenStateData.data());

    // Int16 is used to accumulate the output, preventing overflow (after the Projection MatMul)
    std::unique_ptr<Decoder<float>> outputInt16Decoder =
            MakeDecoder<float>(outputInt16Info, outputInt16Data.data());
    std::unique_ptr<Encoder<float>> outputInt16Encoder =
            MakeEncoder<float>(outputInt16Info, outputInt16Data.data());

    // Create decoders for optional params if they are enabled
    if (!cifgEnabled)
    {
        inputToInputWeightsDecoder = MakeDecoder<float>(
                m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<void>());
        recurrentToInputWeightsDecoder = MakeDecoder<float>(m_RecurrentToInputWeightsTensor->GetTensorInfo(),
                                                            m_RecurrentToInputWeightsTensor->GetConstTensor<void>());
    }

    if (peepholeEnabled)
    {
        if (!cifgEnabled)
        {
            cellToInputWeightsDecoder = MakeDecoder<float>(
                    m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<void>());
        }
        cellToForgetWeightsDecoder = MakeDecoder<float>(
                m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<void>());
        cellToOutputWeightsDecoder = MakeDecoder<float>(
                m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<void>());
    }

    if (projectionEnabled)
    {
        projectionWeightsDecoder = MakeDecoder<float>(
                m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<void>());
        if (m_ProjectionBiasTensor)
        {
            projectionBiasDecoder = MakeDecoder<float>(
                    m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<void>());
        }
    }

    if (layerNormEnabled)
    {
        if (!cifgEnabled)
        {
            inputLayerNormWeightsDecoder = MakeDecoder<float>(m_InputLayerNormWeightsTensor->GetTensorInfo(),
                                                              m_InputLayerNormWeightsTensor->GetConstTensor<void>());

            // Bias only used if layer norm enabled
            armnn::TensorInfo inputGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32,
                    m_InputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
            inputGateBiasDecoder = MakeDecoder<float>(
                    inputGateBiasTensorInfo, m_InputGateBiasTensor->GetConstTensor<void>());
        }

        forgetLayerNormWeightsDecoder = MakeDecoder<float>(
                m_ForgetLayerNormWeightsTensor->GetTensorInfo(),
                m_ForgetLayerNormWeightsTensor->GetConstTensor<void>());
        cellLayerNormWeightsDecoder = MakeDecoder<float>(
                m_CellLayerNormWeightsTensor->GetTensorInfo(), m_CellLayerNormWeightsTensor->GetConstTensor<void>());
        outputLayerNormWeightsDecoder = MakeDecoder<float>(
                m_OutputLayerNormWeightsTensor->GetTensorInfo(),
                m_OutputLayerNormWeightsTensor->GetConstTensor<void>());

        // Bias only used if layer norm enabled
        armnn::TensorInfo forgetGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32,
                m_ForgetLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
        forgetGateBiasDecoder = MakeDecoder<float>(
                forgetGateBiasTensorInfo, m_ForgetGateBiasTensor->GetConstTensor<void>());

        armnn::TensorInfo cellGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32,
                m_CellLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
        cellGateBiasDecoder = MakeDecoder<float>(
                cellGateBiasTensorInfo, m_CellBiasTensor->GetConstTensor<void>());

        armnn::TensorInfo outputGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32,
                m_OutputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
        outputGateBiasDecoder = MakeDecoder<float>(
                outputGateBiasTensorInfo, m_OutputGateBiasTensor->GetConstTensor<void>());
    }

    // Initialize internal state tensors with zeroes.
    if (!cifgEnabled)
    {
        ZeroVector(*inputGateEncoder, stateTensorSize);
    }
    ZeroVector(*forgetGateEncoder, stateTensorSize);
    ZeroVector(*cellGateEncoder, stateTensorSize);
    ZeroVector(*outputGateEncoder, stateTensorSize);
    ZeroVector(*hiddenStateEncoder, stateTensorSize);

    // Input weights * Input
    if (!cifgEnabled)
    {
        MatrixBatchVectorMultiplyAccumulate(*inputToInputWeightsDecoder,
                                            numUnits, inputSize, *inputDecoder, numBatches, *inputGateEncoder);
    }

    MatrixBatchVectorMultiplyAccumulate(*inputToForgetWeightsDecoder,
                                        numUnits, inputSize, *inputDecoder, numBatches, *forgetGateEncoder);

    MatrixBatchVectorMultiplyAccumulate(*inputToCellWeightsDecoder,
                                        numUnits, inputSize, *inputDecoder, numBatches, *cellGateEncoder);

    MatrixBatchVectorMultiplyAccumulate(*inputToOutputWeightsDecoder,
                                        numUnits, inputSize, *inputDecoder, numBatches, *outputGateEncoder);

    // Recurrent weights * OutputStateIn
    if (!cifgEnabled)
    {
        MatrixBatchVectorMultiplyAccumulate(*recurrentToInputWeightsDecoder,
                                            numUnits, outputSize, *outputStateInDecoder, numBatches, *inputGateEncoder);
    }

    MatrixBatchVectorMultiplyAccumulate(*recurrentToForgetWeightsDecoder,
                                        numUnits, outputSize, *outputStateInDecoder, numBatches, *forgetGateEncoder);

    MatrixBatchVectorMultiplyAccumulate(*recurrentToCellWeightsDecoder,
                                        numUnits, outputSize, *outputStateInDecoder, numBatches, *cellGateEncoder);

    MatrixBatchVectorMultiplyAccumulate(*recurrentToOutputWeightsDecoder,
                                        numUnits, outputSize, *outputStateInDecoder, numBatches, *outputGateEncoder);

    // Input gate.
    if (!cifgEnabled)
    {
        if (peepholeEnabled)
        {
            VectorBatchVectorCwiseProductAccumulate(*cellToInputWeightsDecoder,
                                                    numUnits, *cellStateInDecoder, numBatches, *inputGateEncoder);
        }

        if (layerNormEnabled)
        {
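            // Quantize layer norm output to Input Scale * m_InputLayerNormWeightsTensor scale * 1024
            // (mirrors the comments on the forget gate below)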
            inputGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() *
                                               m_InputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
                                               1024);
            inputGateEncoder = MakeEncoder<float>(inputGateInfo, inputGateData.data());

            MeanStddevNormalization(*inputGateDecoder,
                                    *inputGateEncoder, numUnits, numBatches, m_LayerNormEpsilon);

            inputGateDecoder = MakeDecoder<float>(inputGateInfo, inputGateData.data());

            VectorBatchVectorCwiseProduct(*inputLayerNormWeightsDecoder,
                                          numUnits, *inputGateDecoder, numBatches, *inputGateEncoder);

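            // Dequantize layer norm output to (1 / 4096)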
            inputGateInfo.SetQuantizationScale(1.f / 4096);
            inputGateEncoder = MakeEncoder<float>(inputGateInfo, inputGateData.data());

            VectorBatchVectorAdd(*inputGateBiasDecoder,
                                 numUnits, *inputGateDecoder, numBatches, *inputGateEncoder);

            inputGateDecoder = MakeDecoder<float>(inputGateInfo, inputGateData.data());
        }

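        // Re-encode the gate at the cell state scale before applying the activation;
        // the same pattern repeats for the forget, cell and output gates below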
        inputGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale());
        inputGateEncoder = MakeEncoder<float>(inputGateInfo, inputGateData.data());

        // Input gate sigmoid
        Activation(*inputGateDecoder, *inputGateEncoder,
                   TensorInfo({numUnits, numBatches}, internalType),
                   ActivationFunction::Sigmoid, 0, 0);

        inputGateDecoder = MakeDecoder<float>(inputGateInfo, inputGateData.data());
    }

    // Forget gate
    if (peepholeEnabled)
    {
        VectorBatchVectorCwiseProductAccumulate(*cellToForgetWeightsDecoder, numUnits,
                                                *cellStateInDecoder, numBatches, *forgetGateEncoder);
    }

    if (layerNormEnabled)
    {
        // Quantize layer norm output to Input Scale * m_ForgetLayerNormWeightsTensor scale * 1024
        forgetGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() *
                                            m_ForgetLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
                                            1024);
        forgetGateEncoder = MakeEncoder<float>(forgetGateInfo, forgetGateData.data());

        MeanStddevNormalization(*forgetGateDecoder,
                                *forgetGateEncoder, numUnits, numBatches, m_LayerNormEpsilon);

        forgetGateDecoder = MakeDecoder<float>(forgetGateInfo, forgetGateData.data());

        VectorBatchVectorCwiseProduct(*forgetLayerNormWeightsDecoder,
                                      numUnits, *forgetGateDecoder, numBatches, *forgetGateEncoder);

        // Dequantize layer norm output to (1 / 4096)
        forgetGateInfo.SetQuantizationScale(1.f / 4096);
        forgetGateEncoder = MakeEncoder<float>(forgetGateInfo, forgetGateData.data());

        VectorBatchVectorAdd(*forgetGateBiasDecoder,
                             numUnits, *forgetGateDecoder, numBatches, *forgetGateEncoder);

        forgetGateDecoder = MakeDecoder<float>(forgetGateInfo, forgetGateData.data());
    }

    forgetGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale());
    forgetGateEncoder = MakeEncoder<float>(forgetGateInfo, forgetGateData.data());

    // Forget gate sigmoid
    Activation(*forgetGateDecoder, *forgetGateEncoder,
               TensorInfo({numUnits, numBatches}, internalType),
               ActivationFunction::Sigmoid, 0, 0);

    forgetGateDecoder = MakeDecoder<float>(forgetGateInfo, forgetGateData.data());

    // Cell (Modulation) gate
    if (layerNormEnabled)
    {
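        // Quantize layer norm output to Input Scale * m_CellLayerNormWeightsTensor scale * 1024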
        cellGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() *
                                          m_CellLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
                                          1024);
        cellGateEncoder = MakeEncoder<float>(cellGateInfo, cellGateData.data());

        MeanStddevNormalization(*cellGateDecoder, *cellGateEncoder, numUnits, numBatches, m_LayerNormEpsilon);

        cellGateDecoder = MakeDecoder<float>(cellGateInfo, cellGateData.data());

        VectorBatchVectorCwiseProduct(*cellLayerNormWeightsDecoder,
                                      numUnits, *cellGateDecoder, numBatches, *cellGateEncoder);

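        // Dequantize layer norm output to (1 / 4096)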
        cellGateInfo.SetQuantizationScale(1.f / 4096);
        cellGateEncoder = MakeEncoder<float>(cellGateInfo, cellGateData.data());

        VectorBatchVectorAdd(*cellGateBiasDecoder,
                             numUnits, *cellGateDecoder, numBatches, *cellGateEncoder);

        cellGateDecoder = MakeDecoder<float>(cellGateInfo, cellGateData.data());
    }

    cellGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale());
    cellGateEncoder = MakeEncoder<float>(cellGateInfo, cellGateData.data());

    // Cell (Modulation) gate tanH
    Activation(*cellGateDecoder, *cellGateEncoder,
               TensorInfo({numUnits, numBatches}, internalType),
               ActivationFunction::TanH, 1.0f, 1.0f);

    cellGateDecoder = MakeDecoder<float>(cellGateInfo, cellGateData.data());

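    // New cell state: cellStateOut = forgetGate * cellStateIn, with the input contribution
    // (inputGate * cellGate) accumulated on top below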
    VectorVectorCwiseProduct(*forgetGateDecoder, *cellStateInDecoder, stateTensorSize, *cellStateOutEncoder);

    if (cifgEnabled)
    {
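        // With CIFG the input gate is coupled to the forget gate: inputGate = 1 - forgetGate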
        Sub1Vector(*forgetGateDecoder, stateTensorSize, *forgetGateEncoder);
        VectorVectorCwiseProductAccumulate(
                *cellGateDecoder, *forgetGateDecoder, stateTensorSize, *cellStateOutEncoder);
    }
    else
    {
        VectorVectorCwiseProductAccumulate(
                *cellGateDecoder, *inputGateDecoder, stateTensorSize, *cellStateOutEncoder);
    }

    // Final cell state out calculated here
    if (m_Data.m_Parameters.m_CellClip > 0.0)
    {
        ClipVector(*cellStateOutDecoder, stateTensorSize, m_Data.m_Parameters.m_CellClip, *cellStateOutEncoder);
    }

    // Output gate.
    if (peepholeEnabled)
    {
        VectorBatchVectorCwiseProductAccumulate(*cellToOutputWeightsDecoder,
                                                numUnits, *cellStateOutDecoder, numBatches, *outputGateEncoder);
    }

    if (layerNormEnabled)
    {
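        // Quantize layer norm output to Input Scale * m_OutputLayerNormWeightsTensor scale * 1024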
        outputGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() *
                                            m_OutputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
                                            1024);
        outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data());

        MeanStddevNormalization(*outputGateDecoder, *outputGateEncoder, numUnits, numBatches, m_LayerNormEpsilon);

        outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data());

        VectorBatchVectorCwiseProduct(*outputLayerNormWeightsDecoder, numUnits, *outputGateDecoder,
                                      numBatches, *outputGateEncoder);

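        // Dequantize layer norm output to (1 / 4096)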
        outputGateInfo.SetQuantizationScale(1.f / 4096);
        outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data());

        VectorBatchVectorAdd(*outputGateBiasDecoder, numUnits, *outputGateDecoder, numBatches, *outputGateEncoder);

        outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data());
    }

    outputGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale());
    outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data());

    // Output gate sigmoid
    Activation(*outputGateDecoder, *outputGateEncoder,
               TensorInfo({numUnits, numBatches}, internalType),
               ActivationFunction::Sigmoid, 0, 0);

    outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data());

    // Hidden state tanH
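    // Note: the cell gate data buffer is reused here to hold tanH(cellStateOut)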
    Activation(*cellStateOutDecoder, *cellGateEncoder,
               TensorInfo({numUnits, numBatches}, internalType),
               ActivationFunction::TanH, 1.0f, 1.0f);

    // Final hidden state output
    VectorVectorCwiseProduct(*outputGateDecoder, *cellGateDecoder, stateTensorSize, *hiddenStateEncoder);

    // Projection
    if (m_Data.m_Parameters.m_ProjectionEnabled)
    {
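        // Initialise the int16 accumulator with the projection bias (broadcast across batches) when present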
        if (m_ProjectionBiasTensor)
        {
            VectorBatchVectorAssign(*projectionBiasDecoder, outputSize, numBatches, *outputInt16Encoder);
        }

        MatrixBatchVectorMultiplyAccumulate(*projectionWeightsDecoder, outputSize, numUnits, *hiddenStateDecoder,
                                            numBatches, *outputInt16Encoder);

        CopyVector(*outputInt16Decoder, numBatches * outputSize, *outputEncoder);

        if (m_Data.m_Parameters.m_ProjectionClip > 0.0)
        {
            ClipVector(*outputDecoder, numBatches * outputSize, m_Data.m_Parameters.m_ProjectionClip, *outputEncoder);
        }
    }
    else
    {
        // Output has the same quantization scale as the hidden state if projection is disabled
        CopyVector(*hiddenStateDecoder, numBatches * outputSize, *outputEncoder);
    }

    // output == outputStateOut
    CopyVector(*outputDecoder, numBatches * outputSize, *outputStateOutEncoder);
}

} // namespace armnn