//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "RefQLstmWorkload.hpp"
#include "Activation.hpp"
#include "Encoders.hpp"
#include "Decoders.hpp"
#include "LstmUtils.hpp"
#include "RefWorkloadUtils.hpp"

namespace armnn
{

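// The constructor below copies each (optional) weight and bias tensor from the queue
// descriptor into a scoped handle. Handles for features that are disabled on this layer
// (CIFG, peephole, projection, layer normalization) remain null and must not be used.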
RefQLstmWorkload::RefQLstmWorkload(const QLstmQueueDescriptor& descriptor, const WorkloadInfo& info)
    : RefBaseWorkload<QLstmQueueDescriptor>(descriptor, info)
    , m_InputToInputWeightsTensor     (AssignScopedTensorHandle(descriptor.m_InputToInputWeights))
    , m_InputToForgetWeightsTensor    (AssignScopedTensorHandle(descriptor.m_InputToForgetWeights))
    , m_InputToCellWeightsTensor      (AssignScopedTensorHandle(descriptor.m_InputToCellWeights))
    , m_InputToOutputWeightsTensor    (AssignScopedTensorHandle(descriptor.m_InputToOutputWeights))

    , m_RecurrentToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToInputWeights))
    , m_RecurrentToForgetWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToForgetWeights))
    , m_RecurrentToCellWeightsTensor  (AssignScopedTensorHandle(descriptor.m_RecurrentToCellWeights))
    , m_RecurrentToOutputWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToOutputWeights))

    , m_CellToInputWeightsTensor      (AssignScopedTensorHandle(descriptor.m_CellToInputWeights))
    , m_CellToForgetWeightsTensor     (AssignScopedTensorHandle(descriptor.m_CellToForgetWeights))
    , m_CellToOutputWeightsTensor     (AssignScopedTensorHandle(descriptor.m_CellToOutputWeights))

    , m_InputGateBiasTensor           (AssignScopedTensorHandle(descriptor.m_InputGateBias))
    , m_ForgetGateBiasTensor          (AssignScopedTensorHandle(descriptor.m_ForgetGateBias))
    , m_CellBiasTensor                (AssignScopedTensorHandle(descriptor.m_CellBias))
    , m_OutputGateBiasTensor          (AssignScopedTensorHandle(descriptor.m_OutputGateBias))

    , m_ProjectionWeightsTensor       (AssignScopedTensorHandle(descriptor.m_ProjectionWeights))
    , m_ProjectionBiasTensor          (AssignScopedTensorHandle(descriptor.m_ProjectionBias))

    , m_InputLayerNormWeightsTensor   (AssignScopedTensorHandle(descriptor.m_InputLayerNormWeights))
    , m_ForgetLayerNormWeightsTensor  (AssignScopedTensorHandle(descriptor.m_ForgetLayerNormWeights))
    , m_CellLayerNormWeightsTensor    (AssignScopedTensorHandle(descriptor.m_CellLayerNormWeights))
    , m_OutputLayerNormWeightsTensor  (AssignScopedTensorHandle(descriptor.m_OutputLayerNormWeights))
{}

void RefQLstmWorkload::Execute() const
{
    Execute(m_Data.m_Inputs, m_Data.m_Outputs);
}

void RefQLstmWorkload::ExecuteAsync(ExecutionData& executionData)
{
    WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data);
    Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs);
}

void RefQLstmWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
{
    // This is a port of the QLSTM::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs)
    // method from the Android code base.
    // Note: this implementation wraps the arithmetic functions of the LSTM cell in Quantize/Dequantize ops, so all
    // computation is done in the floating-point domain. The arithmetic functions live in LstmUtils.cpp.
    // Refer to: android/frameworks/ml/nn/common/operations/QLSTM.cpp
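    // For reference, this implements the standard LSTM cell equations, where (.) denotes
    // an elementwise product and the bracketed peephole terms are optional:
    //     i_t = sigmoid(W_i x_t + R_i h_{t-1} [+ p_i (.) c_{t-1}] + b_i)   // input gate
    //     f_t = sigmoid(W_f x_t + R_f h_{t-1} [+ p_f (.) c_{t-1}] + b_f)   // forget gate
    //     g_t = tanh   (W_g x_t + R_g h_{t-1}                     + b_g)   // cell (modulation) gate
    //     c_t = f_t (.) c_{t-1} + i_t (.) g_t                              // new cell state
    //     o_t = sigmoid(W_o x_t + R_o h_{t-1} [+ p_o (.) c_t]     + b_o)   // output gate
    //     h_t = o_t (.) tanh(c_t)                                          // new hidden state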
    const DataType& internalType = armnn::DataType::QSymmS16;

    const TensorInfo& inputInfo = GetTensorInfo(inputs[0]);
    const TensorInfo& outputStateInInfo = GetTensorInfo(inputs[1]);
    const TensorInfo& cellStateInInfo = GetTensorInfo(inputs[2]);

    const TensorInfo& outputStateOutInfo = GetTensorInfo(outputs[0]);
    const TensorInfo& cellStateOutInfo = GetTensorInfo(outputs[1]);
    const TensorInfo& outputInfo = GetTensorInfo(outputs[2]);

    const TensorShape& inputShape = inputInfo.GetShape();
    const TensorShape& outputStateInShape = outputStateInInfo.GetShape();
    const TensorShape& cellStateInShape = cellStateInInfo.GetShape();

    // Infer numBatches, inputSize, outputSize and numUnits
    const uint32_t numBatches = inputShape[0];
    const uint32_t inputSize = inputShape[1];
    const uint32_t outputSize = outputStateInShape[1];
    const uint32_t numUnits = cellStateInShape[1];
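    // These dimensions come from the 2D input tensors: input is [numBatches, inputSize],
    // outputStateIn is [numBatches, outputSize] and cellStateIn is [numBatches, numUnits].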

    // Optional param settings
    const bool cifgEnabled = m_Data.m_Parameters.m_CifgEnabled;
    const bool peepholeEnabled = m_Data.m_Parameters.m_PeepholeEnabled;
    const bool projectionEnabled = m_Data.m_Parameters.m_ProjectionEnabled;
    const bool layerNormEnabled = m_Data.m_Parameters.m_LayerNormEnabled;

    // Input decoders
    std::unique_ptr<Decoder<float>> inputDecoder =
            MakeDecoder<float>(inputInfo, inputs[0]->Map());
    std::unique_ptr<Decoder<float>> outputStateInDecoder =
            MakeDecoder<float>(outputStateInInfo, inputs[1]->Map());
    std::unique_ptr<Decoder<float>> cellStateInDecoder =
            MakeDecoder<float>(cellStateInInfo, inputs[2]->Map());

    // Output decoders
    std::unique_ptr<Decoder<float>> outputStateOutDecoder =
            MakeDecoder<float>(outputStateOutInfo, outputs[0]->Map());
    std::unique_ptr<Decoder<float>> cellStateOutDecoder =
            MakeDecoder<float>(cellStateOutInfo, outputs[1]->Map());
    std::unique_ptr<Decoder<float>> outputDecoder =
            MakeDecoder<float>(outputInfo, outputs[2]->Map());

    // Output encoders
    std::unique_ptr<Encoder<float>> outputStateOutEncoder =
            MakeEncoder<float>(outputStateOutInfo, outputs[0]->Map());
    std::unique_ptr<Encoder<float>> cellStateOutEncoder =
            MakeEncoder<float>(cellStateOutInfo, outputs[1]->Map());
    std::unique_ptr<Encoder<float>> outputEncoder =
            MakeEncoder<float>(outputInfo, outputs[2]->Map());
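    // The Decoder<float>/Encoder<float> objects above dequantize on read
    // (float = scale * (q - zeroPoint)) and quantize on write (q = round(float / scale)
    // + zeroPoint), so the gate arithmetic below can be expressed entirely in floating
    // point while the backing buffers stay quantized.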

    // Weights decoders
    std::unique_ptr<Decoder<float>> inputToForgetWeightsDecoder = MakeDecoder<float>(
            m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> inputToCellWeightsDecoder = MakeDecoder<float>(
            m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> inputToOutputWeightsDecoder = MakeDecoder<float>(
            m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<void>());

    std::unique_ptr<Decoder<float>> recurrentToForgetWeightsDecoder = MakeDecoder<float>(
            m_RecurrentToForgetWeightsTensor->GetTensorInfo(),
            m_RecurrentToForgetWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> recurrentToCellWeightsDecoder = MakeDecoder<float>(
            m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> recurrentToOutputWeightsDecoder = MakeDecoder<float>(
            m_RecurrentToOutputWeightsTensor->GetTensorInfo(),
            m_RecurrentToOutputWeightsTensor->GetConstTensor<void>());

    // Optional CIFG params
    std::unique_ptr<Decoder<float>> inputToInputWeightsDecoder;
    std::unique_ptr<Decoder<float>> recurrentToInputWeightsDecoder;
    std::unique_ptr<Decoder<float>> inputGateBiasDecoder;

    // Optional Peephole params
    std::unique_ptr<Decoder<float>> cellToInputWeightsDecoder;
    std::unique_ptr<Decoder<float>> cellToForgetWeightsDecoder;
    std::unique_ptr<Decoder<float>> cellToOutputWeightsDecoder;

    // Optional Projection params
    std::unique_ptr<Decoder<float>> projectionWeightsDecoder;
    std::unique_ptr<Decoder<float>> projectionBiasDecoder;

    // Optional Layer Norm params
    std::unique_ptr<Decoder<float>> inputLayerNormWeightsDecoder;
    std::unique_ptr<Decoder<float>> forgetLayerNormWeightsDecoder;
    std::unique_ptr<Decoder<float>> cellLayerNormWeightsDecoder;
    std::unique_ptr<Decoder<float>> outputLayerNormWeightsDecoder;

    // Biases are only used when Layer Norm is enabled. Scale is defined as (XLayerNormWeights Scale / 1024)
    std::unique_ptr<Decoder<float>> forgetGateBiasDecoder;
    std::unique_ptr<Decoder<float>> cellGateBiasDecoder;
    std::unique_ptr<Decoder<float>> outputGateBiasDecoder;

    // Int16 vectors for internal state data (to be decoded/encoded)
    const uint32_t stateTensorSize = numBatches * numUnits;
    std::vector<int16_t> inputGateData(stateTensorSize);
    std::vector<int16_t> cellGateData(stateTensorSize);
    std::vector<int16_t> forgetGateData(stateTensorSize);
    std::vector<int16_t> outputGateData(stateTensorSize);
    std::vector<int8_t> hiddenStateData(stateTensorSize); // int8_t to match the QAsymmS8 hiddenStateInfo below
    std::vector<int16_t> outputInt16Data(numBatches * outputSize);
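    // The gate scratch buffers are QSymmS16: the per-gate intermediate scales set just
    // below are parameters of the layer, and symmetric 16-bit storage gives the gate
    // pre-activations more headroom than 8 bits before sigmoid/tanh squashes them.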

    armnn::TensorInfo inputGateInfo(
            {numBatches, numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_InputIntermediateScale, 0);
    armnn::TensorInfo cellGateInfo(
            {numBatches, numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_CellIntermediateScale, 0);
    armnn::TensorInfo forgetGateInfo(
            {numBatches, numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_ForgetIntermediateScale, 0);
    armnn::TensorInfo outputGateInfo(
            {numBatches, numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_OutputIntermediateScale, 0);
    armnn::TensorInfo hiddenStateInfo({numBatches, numUnits},
                                      armnn::DataType::QAsymmS8,
                                      m_Data.m_Parameters.m_HiddenStateScale,
                                      m_Data.m_Parameters.m_HiddenStateZeroPoint);
    armnn::TensorInfo outputInt16Info({numBatches, outputSize},
                                      armnn::DataType::QSymmS16,
                                      outputInfo.GetQuantizationScale(),
                                      outputInfo.GetQuantizationOffset());

    // Decoders/Encoders for internal states
    std::unique_ptr<Decoder<float>> inputGateDecoder =
            MakeDecoder<float>(inputGateInfo, inputGateData.data());
    std::unique_ptr<Decoder<float>> cellGateDecoder =
            MakeDecoder<float>(cellGateInfo, cellGateData.data());
    std::unique_ptr<Decoder<float>> forgetGateDecoder =
            MakeDecoder<float>(forgetGateInfo, forgetGateData.data());
    std::unique_ptr<Decoder<float>> outputGateDecoder =
            MakeDecoder<float>(outputGateInfo, outputGateData.data());
    std::unique_ptr<Decoder<float>> hiddenStateDecoder =
            MakeDecoder<float>(hiddenStateInfo, hiddenStateData.data());

    std::unique_ptr<Encoder<float>> inputGateEncoder =
            MakeEncoder<float>(inputGateInfo, inputGateData.data());
    std::unique_ptr<Encoder<float>> cellGateEncoder =
            MakeEncoder<float>(cellGateInfo, cellGateData.data());
    std::unique_ptr<Encoder<float>> forgetGateEncoder =
            MakeEncoder<float>(forgetGateInfo, forgetGateData.data());
    std::unique_ptr<Encoder<float>> outputGateEncoder =
            MakeEncoder<float>(outputGateInfo, outputGateData.data());
    std::unique_ptr<Encoder<float>> hiddenStateEncoder =
            MakeEncoder<float>(hiddenStateInfo, hiddenStateData.data());

    // Int16 used to accumulate output to prevent overflowing (after Projection MatMul)
    std::unique_ptr<Decoder<float>> outputInt16Decoder =
            MakeDecoder<float>(outputInt16Info, outputInt16Data.data());
    std::unique_ptr<Encoder<float>> outputInt16Encoder =
            MakeEncoder<float>(outputInt16Info, outputInt16Data.data());

    // Create decoders for optional params if they are enabled
    if (!cifgEnabled)
    {
        inputToInputWeightsDecoder = MakeDecoder<float>(
                m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<void>());
        recurrentToInputWeightsDecoder = MakeDecoder<float>(m_RecurrentToInputWeightsTensor->GetTensorInfo(),
                                                            m_RecurrentToInputWeightsTensor->GetConstTensor<void>());
    }

    if (peepholeEnabled)
    {
        if (!cifgEnabled)
        {
            cellToInputWeightsDecoder = MakeDecoder<float>(
                    m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<void>());
        }
        cellToForgetWeightsDecoder = MakeDecoder<float>(
                m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<void>());
        cellToOutputWeightsDecoder = MakeDecoder<float>(
                m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<void>());
    }

    if (projectionEnabled)
    {
        projectionWeightsDecoder = MakeDecoder<float>(
                m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<void>());
        if (m_ProjectionBiasTensor)
        {
            projectionBiasDecoder = MakeDecoder<float>(
                    m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<void>());
        }
    }

    if (layerNormEnabled)
    {
        if (!cifgEnabled)
        {
            inputLayerNormWeightsDecoder = MakeDecoder<float>(m_InputLayerNormWeightsTensor->GetTensorInfo(),
                                                              m_InputLayerNormWeightsTensor->GetConstTensor<void>());

            // Bias only used if layer norm enabled
            armnn::TensorInfo inputGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32,
                    m_InputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
            inputGateBiasDecoder = MakeDecoder<float>(
                    inputGateBiasTensorInfo, m_InputGateBiasTensor->GetConstTensor<void>());
        }

        forgetLayerNormWeightsDecoder = MakeDecoder<float>(
                m_ForgetLayerNormWeightsTensor->GetTensorInfo(),
                m_ForgetLayerNormWeightsTensor->GetConstTensor<void>());
        cellLayerNormWeightsDecoder = MakeDecoder<float>(
                m_CellLayerNormWeightsTensor->GetTensorInfo(), m_CellLayerNormWeightsTensor->GetConstTensor<void>());
        outputLayerNormWeightsDecoder = MakeDecoder<float>(
                m_OutputLayerNormWeightsTensor->GetTensorInfo(),
                m_OutputLayerNormWeightsTensor->GetConstTensor<void>());

        // Bias only used if layer norm enabled
        armnn::TensorInfo forgetGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32,
                m_ForgetLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
        forgetGateBiasDecoder = MakeDecoder<float>(
                forgetGateBiasTensorInfo, m_ForgetGateBiasTensor->GetConstTensor<void>());

        armnn::TensorInfo cellGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32,
                m_CellLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
        cellGateBiasDecoder = MakeDecoder<float>(
                cellGateBiasTensorInfo, m_CellBiasTensor->GetConstTensor<void>());

        armnn::TensorInfo outputGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32,
                m_OutputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
        outputGateBiasDecoder = MakeDecoder<float>(
                outputGateBiasTensorInfo, m_OutputGateBiasTensor->GetConstTensor<void>());
    }

    // Initialize internal state tensors with zeroes.
    if (!cifgEnabled)
    {
        ZeroVector(*inputGateEncoder, stateTensorSize);
    }
    ZeroVector(*forgetGateEncoder, stateTensorSize);
    ZeroVector(*cellGateEncoder, stateTensorSize);
    ZeroVector(*outputGateEncoder, stateTensorSize);
    ZeroVector(*hiddenStateEncoder, stateTensorSize);
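    // The gate buffers must start at zero because every MatrixBatchVectorMultiplyAccumulate
    // call below accumulates into them rather than overwriting them.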

    // Input weights * Input
    if (!cifgEnabled)
    {
        MatrixBatchVectorMultiplyAccumulate(*inputToInputWeightsDecoder,
                                            numUnits, inputSize, *inputDecoder, numBatches, *inputGateEncoder);
    }

    MatrixBatchVectorMultiplyAccumulate(*inputToForgetWeightsDecoder,
                                        numUnits, inputSize, *inputDecoder, numBatches, *forgetGateEncoder);

    MatrixBatchVectorMultiplyAccumulate(*inputToCellWeightsDecoder,
                                        numUnits, inputSize, *inputDecoder, numBatches, *cellGateEncoder);

    MatrixBatchVectorMultiplyAccumulate(*inputToOutputWeightsDecoder,
                                        numUnits, inputSize, *inputDecoder, numBatches, *outputGateEncoder);

    // Recurrent weights * OutputStateIn
    if (!cifgEnabled)
    {
        MatrixBatchVectorMultiplyAccumulate(*recurrentToInputWeightsDecoder,
                                            numUnits, outputSize, *outputStateInDecoder, numBatches,
                                            *inputGateEncoder);
    }

    MatrixBatchVectorMultiplyAccumulate(*recurrentToForgetWeightsDecoder,
                                        numUnits, outputSize, *outputStateInDecoder, numBatches, *forgetGateEncoder);

    MatrixBatchVectorMultiplyAccumulate(*recurrentToCellWeightsDecoder,
                                        numUnits, outputSize, *outputStateInDecoder, numBatches, *cellGateEncoder);

    MatrixBatchVectorMultiplyAccumulate(*recurrentToOutputWeightsDecoder,
                                        numUnits, outputSize, *outputStateInDecoder, numBatches, *outputGateEncoder);
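    // At this point each enabled gate buffer holds its pre-activation
    // W_x * x_t + R_x * h_{t-1}; the peephole, layer norm and bias terms are added per
    // gate below before the nonlinearity is applied.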

    // Input gate.
    if (!cifgEnabled)
    {
        if (peepholeEnabled)
        {
            VectorBatchVectorCwiseProductAccumulate(*cellToInputWeightsDecoder,
                                                    numUnits, *cellStateInDecoder, numBatches, *inputGateEncoder);
        }

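        // The layer-norm path below re-creates the gate's encoder/decoder after each
        // scale change: the codecs capture the quantization scale when they are built,
        // so updating inputGateInfo alone would not affect an existing encoder.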
        if (layerNormEnabled)
        {
            inputGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() *
                                               m_InputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
                                               1024);
            inputGateEncoder = MakeEncoder<float>(inputGateInfo, inputGateData.data());

            MeanStddevNormalization(*inputGateDecoder,
                                    *inputGateEncoder, numUnits, numBatches, m_LayerNormEpsilon);

            inputGateDecoder = MakeDecoder<float>(inputGateInfo, inputGateData.data());

            VectorBatchVectorCwiseProduct(*inputLayerNormWeightsDecoder,
                                          numUnits, *inputGateDecoder, numBatches, *inputGateEncoder);

            inputGateInfo.SetQuantizationScale(1.f / 4096);
            inputGateEncoder = MakeEncoder<float>(inputGateInfo, inputGateData.data());

            VectorBatchVectorAdd(*inputGateBiasDecoder,
                                 numUnits, *inputGateDecoder, numBatches, *inputGateEncoder);

            inputGateDecoder = MakeDecoder<float>(inputGateInfo, inputGateData.data());
        }

        inputGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale());
        inputGateEncoder = MakeEncoder<float>(inputGateInfo, inputGateData.data());

        // Input gate sigmoid
        Activation(*inputGateDecoder, *inputGateEncoder,
                   TensorInfo({numUnits, numBatches}, internalType),
                   ActivationFunction::Sigmoid, 0, 0);

        inputGateDecoder = MakeDecoder<float>(inputGateInfo, inputGateData.data());
    }
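    // With CIFG the input gate is not computed explicitly; it is derived from the
    // forget gate as i_t = 1 - f_t (see the Sub1Vector call further below).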

    // Forget gate
    if (peepholeEnabled)
    {
        VectorBatchVectorCwiseProductAccumulate(*cellToForgetWeightsDecoder, numUnits,
                                                *cellStateInDecoder, numBatches, *forgetGateEncoder);
    }

    if (layerNormEnabled)
    {
        // Requantize the layer norm output to scale (Input Scale * m_ForgetLayerNormWeightsTensor Scale * 1024)
        forgetGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() *
                                            m_ForgetLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
                                            1024);
        forgetGateEncoder = MakeEncoder<float>(forgetGateInfo, forgetGateData.data());

        MeanStddevNormalization(*forgetGateDecoder,
                                *forgetGateEncoder, numUnits, numBatches, m_LayerNormEpsilon);

        forgetGateDecoder = MakeDecoder<float>(forgetGateInfo, forgetGateData.data());

        VectorBatchVectorCwiseProduct(*forgetLayerNormWeightsDecoder,
                                      numUnits, *forgetGateDecoder, numBatches, *forgetGateEncoder);

        // Requantize the layer norm output to scale (1 / 4096)
        forgetGateInfo.SetQuantizationScale(1.f / 4096);
        forgetGateEncoder = MakeEncoder<float>(forgetGateInfo, forgetGateData.data());

        VectorBatchVectorAdd(*forgetGateBiasDecoder,
                             numUnits, *forgetGateDecoder, numBatches, *forgetGateEncoder);

        forgetGateDecoder = MakeDecoder<float>(forgetGateInfo, forgetGateData.data());
    }

    forgetGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale());
    forgetGateEncoder = MakeEncoder<float>(forgetGateInfo, forgetGateData.data());

    // Forget gate sigmoid
    Activation(*forgetGateDecoder, *forgetGateEncoder,
               TensorInfo({numUnits, numBatches}, internalType),
               ActivationFunction::Sigmoid, 0, 0);

    forgetGateDecoder = MakeDecoder<float>(forgetGateInfo, forgetGateData.data());

    // Cell (Modulation) gate
    if (layerNormEnabled)
    {
        cellGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() *
                                          m_CellLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
                                          1024);
        cellGateEncoder = MakeEncoder<float>(cellGateInfo, cellGateData.data());

        MeanStddevNormalization(*cellGateDecoder, *cellGateEncoder, numUnits, numBatches, m_LayerNormEpsilon);

        cellGateDecoder = MakeDecoder<float>(cellGateInfo, cellGateData.data());

        VectorBatchVectorCwiseProduct(*cellLayerNormWeightsDecoder,
                                      numUnits, *cellGateDecoder, numBatches, *cellGateEncoder);

        cellGateInfo.SetQuantizationScale(1.f / 4096);
        cellGateEncoder = MakeEncoder<float>(cellGateInfo, cellGateData.data());

        VectorBatchVectorAdd(*cellGateBiasDecoder,
                             numUnits, *cellGateDecoder, numBatches, *cellGateEncoder);

        cellGateDecoder = MakeDecoder<float>(cellGateInfo, cellGateData.data());
    }

    cellGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale());
    cellGateEncoder = MakeEncoder<float>(cellGateInfo, cellGateData.data());

    // Cell (Modulation) gate tanH
    Activation(*cellGateDecoder, *cellGateEncoder,
               TensorInfo({numUnits, numBatches}, internalType),
               ActivationFunction::TanH, 1.0f, 1.0f);

    cellGateDecoder = MakeDecoder<float>(cellGateInfo, cellGateData.data());

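    // New cell state: c_t = f_t * c_{t-1} + i_t * g_t (elementwise). With CIFG the
    // input gate is implicit, so the second term becomes (1 - f_t) * g_t.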
    VectorVectorCwiseProduct(*forgetGateDecoder, *cellStateInDecoder, stateTensorSize, *cellStateOutEncoder);

    if (cifgEnabled)
    {
        Sub1Vector(*forgetGateDecoder, stateTensorSize, *forgetGateEncoder);
        VectorVectorCwiseProductAccumulate(
                *cellGateDecoder, *forgetGateDecoder, stateTensorSize, *cellStateOutEncoder);
    }
    else
    {
        VectorVectorCwiseProductAccumulate(
                *cellGateDecoder, *inputGateDecoder, stateTensorSize, *cellStateOutEncoder);
    }

    // Final cell state out calculated here
    if (m_Data.m_Parameters.m_CellClip > 0.0)
    {
        ClipVector(*cellStateOutDecoder, stateTensorSize, m_Data.m_Parameters.m_CellClip, *cellStateOutEncoder);
    }

    // Output gate.
    if (peepholeEnabled)
    {
        VectorBatchVectorCwiseProductAccumulate(*cellToOutputWeightsDecoder,
                                                numUnits, *cellStateOutDecoder, numBatches, *outputGateEncoder);
    }

    if (layerNormEnabled)
    {
        outputGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() *
                                            m_OutputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
                                            1024);
        outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data());

        MeanStddevNormalization(*outputGateDecoder, *outputGateEncoder, numUnits, numBatches, m_LayerNormEpsilon);

        outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data());

        VectorBatchVectorCwiseProduct(*outputLayerNormWeightsDecoder, numUnits, *outputGateDecoder,
                                      numBatches, *outputGateEncoder);

        outputGateInfo.SetQuantizationScale(1.f / 4096);
        outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data());

        VectorBatchVectorAdd(*outputGateBiasDecoder, numUnits, *outputGateDecoder, numBatches, *outputGateEncoder);

        outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data());
    }

    outputGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale());
    outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data());

    // Output gate sigmoid
    Activation(*outputGateDecoder, *outputGateEncoder,
               TensorInfo({numUnits, numBatches}, internalType),
               ActivationFunction::Sigmoid, 0, 0);

    outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data());

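    // h_t = o_t * tanh(c_t). The tanh result is written into the now-free cell gate
    // buffer, whose scale was already set to the cell state scale above.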
    // Hidden state tanH
    Activation(*cellStateOutDecoder, *cellGateEncoder,
               TensorInfo({numUnits, numBatches}, internalType),
               ActivationFunction::TanH, 1.0f, 1.0f);

    // Final hidden state output
    VectorVectorCwiseProduct(*outputGateDecoder, *cellGateDecoder, stateTensorSize, *hiddenStateEncoder);

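    // Projection: when enabled, output = clip(W_proj * h_t + b_proj). The product is
    // accumulated in the int16 buffer first because the int8 output type could
    // overflow during the matrix multiply accumulation.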
    // Projection
    if (m_Data.m_Parameters.m_ProjectionEnabled)
    {
        if (m_ProjectionBiasTensor)
        {
            VectorBatchVectorAssign(*projectionBiasDecoder, outputSize, numBatches, *outputInt16Encoder);
        }

        MatrixBatchVectorMultiplyAccumulate(*projectionWeightsDecoder, outputSize, numUnits, *hiddenStateDecoder,
                                            numBatches, *outputInt16Encoder);

        CopyVector(*outputInt16Decoder, numBatches * outputSize, *outputEncoder);

        if (m_Data.m_Parameters.m_ProjectionClip > 0.0)
        {
            ClipVector(*outputDecoder, numBatches * outputSize, m_Data.m_Parameters.m_ProjectionClip,
                       *outputEncoder);
        }
    }
    else
    {
        // Output has the same quantization scale as the hidden state when projection is disabled
        CopyVector(*hiddenStateDecoder, numBatches * outputSize, *outputEncoder);
    }

    // outputStateOut is a copy of output
    CopyVector(*outputDecoder, numBatches * outputSize, *outputStateOutEncoder);
}

} // namespace armnn