/*
 * Copyright (c) 2020-2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/runtime/NEON/functions/NEQLSTMLayer.h"

#include "arm_compute/core/ITensorPack.h"
#include "arm_compute/core/KernelDescriptors.h"
#include "arm_compute/core/QuantizationInfo.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/utils/misc/InfoHelpers.h"
#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
#include "arm_compute/runtime/NEON/NEScheduler.h"
#include "src/common/utils/Log.h"
#include "src/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.h"
#include "src/core/helpers/WindowHelpers.h"
#include "src/cpu/kernels/CpuGemmLowpMatrixReductionKernel.h"

namespace arm_compute
{
using namespace arm_compute::utils::info_helpers;
namespace
{
Status validate_mm(GEMMLowpOutputStageInfo &gemmlowp_info, const ITensorInfo *mm_input, const ITensorInfo *mm_weights, const ITensorInfo *bias,
                   float gemmlowp_scale, const TensorInfo *mm_res_info, const TensorInfo *outstage_tensor_info)
{
    ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpMatrixMultiplyCore::validate(mm_input, mm_weights, nullptr, mm_res_info));
    ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(gemmlowp_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
    ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpOutputStage::validate(mm_res_info, bias, outstage_tensor_info, gemmlowp_info));
    return Status{};
}
} // namespace

Status NEQLSTMLayer::validate_layer_norm(const ITensorInfo &in, const ITensorInfo &weight, const ITensorInfo &bias)
{
    // Output quantization scale will be different, but ignored here
    // since it will be configured at configure() stage.
    const TensorInfo out{ in };
    return NEQLSTMLayerNormalizationKernel::validate(&in, &out, &weight, &bias);
}

void NEQLSTMLayer::configure_layer_norm(NEQLSTMLayer::LayerNormGate g, const ITensor *in)
{
    ARM_COMPUTE_ERROR_ON(!_has_layer_norm);

    Tensor &out = get_layer_norm_output(g);
    _memory_group.manage(&out);
    out.allocator()->init(*(in->info()));

    get_layer_norm(g) = std::make_unique<NEQLSTMLayerNormalizationKernel>();
    get_layer_norm(g)->configure(in, &out, get_layer_norm_weight(g), get_layer_norm_bias(g));
}

NEQLSTMLayer::TensorCopyKernel::~TensorCopyKernel() = default;

Status NEQLSTMLayer::TensorCopyKernel::validate(const ITensorInfo &src, const ITensorInfo &dst)
{
    ARM_COMPUTE_RETURN_ERROR_ON(src.tensor_shape().num_dimensions() > max_dimension_supported);
    ARM_COMPUTE_RETURN_ERROR_ON(dst.tensor_shape().num_dimensions() > max_dimension_supported);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&src, &dst);
    ARM_COMPUTE_RETURN_ERROR_ON(dst.tensor_shape().y() != src.tensor_shape().y());
    return Status{};
}

void NEQLSTMLayer::TensorCopyKernel::configure(ITensor &src, ITensor &dst)
{
    ARM_COMPUTE_ERROR_THROW_ON(NEQLSTMLayer::TensorCopyKernel::validate(*src.info(), *dst.info()));
    ARM_COMPUTE_LOG_PARAMS(src, dst);

    _src      = &src;
    _dst      = &dst;
    _row_size = std::min(_src->info()->tensor_shape().x(), _dst->info()->tensor_shape().x());
    _window   = calculate_max_window(*_src->info(), Steps());
}

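// Copies _row_size bytes of each row from _src to _dst. This copy is used when
// the projection output and the hidden state have different widths (num_units
// != output_size), so only the overlapping part of each row is transferred.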
void NEQLSTMLayer::TensorCopyKernel::run()
{
    Iterator input_iter{ _src, _window };
    Iterator output_iter{ _dst, _window };

    execute_window_loop(_window, [&](const Coordinates &)
    {
        memcpy(output_iter.ptr(), input_iter.ptr(), _row_size);
    },
    input_iter, output_iter);
}

NEQLSTMLayer::~NEQLSTMLayer() = default;

NEQLSTMLayer::NEQLSTMLayer(std::shared_ptr<IMemoryManager> memory_manager)
    : _memory_group(),
      _dequantize_input_to_forget_weights(),
      _quantize_input_to_forget_weights(),
      _transpose_input_to_forget_weights(),
      _transpose_input_to_cell_weights(),
      _transpose_input_to_output_weights(),
      _transpose_input_to_input_weights(),
      _transpose_recurrent_to_forget_weights(),
      _transpose_recurrent_to_cell_weights(),
      _transpose_recurrent_to_output_weights(),
      _transpose_recurrent_to_input_weights(),
      _transpose_projection_weights(),
      _input_to_input_reduction(),
      _recurrent_to_input_reduction(),
      _input_to_forget_reduction(),
      _recurrent_to_forget_reduction(),
      _input_to_cell_reduction(),
      _recurrent_to_cell_reduction(),
      _input_to_output_reduction(),
      _recurrent_to_output_reduction(),
      _projection_reduction(),
      _projection_bias_add(),
      _mm_input_to_forget(),
      _mm_recurrent_to_forget(),
      _pixelwise_mul_cell_to_forget(),
      _input_to_forget_outstage(),
      _recurrent_to_forget_outstage(),
      _cell_to_forget_outstage(),
      _accumulate_input_recurrent_forget(),
      _accumulate_cell_forget(),
      _forget_gate_sigmoid(),
      _mm_input_to_cell(),
      _input_to_cell_outstage(),
      _mm_recurrent_to_cell(),
      _recurrent_to_cell_outstage(),
      _accumulate_input_recurrent_modulation(),
      _cell_gate_tanh(),
      _input_gate_sub(),
      _mm_input_to_input(),
      _input_to_input_outstage(),
      _mm_recurrent_to_input(),
      _recurrent_to_input_outstage(),
      _accumulate_input_recurrent_input(),
      _pixelwise_mul_cell_to_input(),
      _cell_to_input_outstage(),
      _accumulate_cell_input(),
      _input_gate_sigmoid(),
      _pixelwise_mul_forget_cell(),
      _pixelwise_mul_input_cell(),
      _add_forget_cell(),
      _cell_clip(),
      _mm_input_to_output(),
      _input_to_output_outstage(),
      _mm_recurrent_to_output(),
      _recurrent_to_output_outstage(),
      _accumulate_input_recurrent_output(),
      _pixelwise_mul_cell_to_output(),
      _cell_to_output_outstage(),
      _accumulate_cell_to_output(),
      _output_gate_sigmoid(),
      _hidden_tanh(),
      _pixelwise_mul_hidden(),
      _hidden_outstage(),
      _mm_projection(),
      _projection_outstage(),
      _accumulate_projection(),
      _projection_clip(),
      _projection_bias_copy(),
      _projection_output_to_accumulate_copy(),
      _projection_accumulate_to_output_copy(),
      _hidden_to_output_copy(),
      _layer_norms(),
      _copy_output(),
      _layer_norm_weights(),
      _layer_norm_bias(),
      _layer_norm_output()
{
    _memory_group = MemoryGroup(std::move(memory_manager));
}

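// Configures one GEMMLowp matrix multiplication (accumulating in S32) and the
// output stage that requantizes its result into a QSYMM16 intermediate. The
// caller passes the effective rescale factor
//     gemmlowp_scale = input_scale * weights_scale / outstage_scale,
// which is decomposed here into a fixed-point multiplier and shift for the
// output stage. The intermediate S32 buffer is memory-managed so it can be
// reused across gates.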
void NEQLSTMLayer::configure_mm(NEGEMMLowpMatrixMultiplyCore &mm, NEGEMMLowpOutputStage &outstage, GEMMLowpOutputStageInfo &gemmlowp_info,
                                const ITensor *mm_input, const ITensor *mm_weights, const ITensor *bias,
                                Tensor *mm_res, Tensor *outstage_res, float gemmlowp_scale,
                                const TensorInfo &mm_res_info, const TensorInfo &outstage_tensor_info)
{
    _memory_group.manage(mm_res);
    _memory_group.manage(outstage_res);

    mm_res->allocator()->init(mm_res_info);
    outstage_res->allocator()->init(outstage_tensor_info);

    // Configure matrix-multiplication
    mm.configure(mm_input, mm_weights, nullptr, mm_res);

    // Configure output stage
    quantization::calculate_quantized_multiplier(gemmlowp_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift);
    outstage.configure(mm_res, bias, outstage_res, gemmlowp_info);
    mm_res->allocator()->allocate();
}

void NEQLSTMLayer::configure(const ITensor *input,
                             const ITensor *input_to_forget_weights, const ITensor *input_to_cell_weights, const ITensor *input_to_output_weights,
                             const ITensor *recurrent_to_forget_weights, const ITensor *recurrent_to_cell_weights, const ITensor *recurrent_to_output_weights,
                             const ITensor *forget_gate_bias, const ITensor *cell_bias, const ITensor *output_gate_bias,
                             const ITensor *cell_state_in, ITensor *output_state_in,
                             ITensor *cell_state_out, ITensor *output_state_out, ITensor *output,
                             const LSTMParams<ITensor> &lstm_params)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
                                 recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
                                 forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in, cell_state_out, output_state_out);

    ARM_COMPUTE_LOG_PARAMS(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
                           recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
                           forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in, cell_state_out, output_state_out);

    // Set lstm parameters
    LSTMParams<ITensorInfo> lstm_params_info{};
    build_lstm_params_tensor_info(lstm_params, &lstm_params_info);

    _input_to_forget_weights_transposed.info()->set_quantization_info(input_to_forget_weights->info()->quantization_info());
    _input_to_cell_weights_transposed.info()->set_quantization_info(input_to_cell_weights->info()->quantization_info());
    _input_to_output_weights_transposed.info()->set_quantization_info(input_to_output_weights->info()->quantization_info());
    _recurrent_to_forget_weights_transposed.info()->set_quantization_info(recurrent_to_forget_weights->info()->quantization_info());
    _recurrent_to_cell_weights_transposed.info()->set_quantization_info(recurrent_to_cell_weights->info()->quantization_info());
    _recurrent_to_output_weights_transposed.info()->set_quantization_info(recurrent_to_output_weights->info()->quantization_info());

    if(input_to_forget_weights->info()->data_type() == DataType::QASYMM8_SIGNED)
    {
        _convert_input_to_forget_weights_to_qsymm8 = true;
        // Setup dequantize output tensor to go from QASYMM8_SIGNED -> F32

        _input_to_forget_weights_f32.allocator()->init(TensorInfo(input_to_forget_weights->info()->tensor_shape(), 1, DataType::F32)
                                                       .set_data_layout(input_to_forget_weights->info()->data_layout()));
        // Setup the quantize output tensor to go from F32 -> QSYMM8
        _input_to_forget_weights_symm8.allocator()->init((TensorInfo(input_to_forget_weights->info()->tensor_shape(), 1, DataType::QSYMM8)
                                                          .set_data_layout(input_to_forget_weights->info()->data_layout())
                                                          .set_quantization_info(input_to_forget_weights->info()->quantization_info())));

        _dequantize_input_to_forget_weights.configure(input_to_forget_weights, &_input_to_forget_weights_f32);
        _quantize_input_to_forget_weights.configure(&_input_to_forget_weights_f32, &_input_to_forget_weights_symm8);

        ARM_COMPUTE_ERROR_THROW_ON(NEQLSTMLayer::validate(input->info(), _input_to_forget_weights_symm8.info(), input_to_cell_weights->info(), input_to_output_weights->info(),
                                                          recurrent_to_forget_weights->info(), recurrent_to_cell_weights->info(), recurrent_to_output_weights->info(),
                                                          forget_gate_bias->info(), cell_bias->info(), output_gate_bias->info(),
                                                          cell_state_in->info(), output_state_in->info(), cell_state_out->info(), output_state_out->info(), output->info(),
                                                          lstm_params_info));
    }
    else
    {
        ARM_COMPUTE_ERROR_THROW_ON(NEQLSTMLayer::validate(input->info(), input_to_forget_weights->info(), input_to_cell_weights->info(), input_to_output_weights->info(),
                                                          recurrent_to_forget_weights->info(), recurrent_to_cell_weights->info(), recurrent_to_output_weights->info(),
                                                          forget_gate_bias->info(), cell_bias->info(), output_gate_bias->info(),
                                                          cell_state_in->info(), output_state_in->info(), cell_state_out->info(), output_state_out->info(), output->info(),
                                                          lstm_params_info));
    }

    const int batch_size  = input->info()->dimension(1);
    const int num_units   = input_to_output_weights->info()->dimension(1);
    const int output_size = output_state_out->info()->dimension(_out_state_output_size_dimension_idx);

    const UniformQuantizationInfo qinput           = input->info()->quantization_info().uniform();
    const UniformQuantizationInfo qcell_state_in   = cell_state_in->info()->quantization_info().uniform();
    const UniformQuantizationInfo qoutput_state_in = output_state_in->info()->quantization_info().uniform();

    _projection_bias             = lstm_params.projection_bias();
    _input_to_forget_weights     = (input_to_forget_weights->info()->data_type() == DataType::QASYMM8_SIGNED) ? &_input_to_forget_weights_symm8 : input_to_forget_weights;
    _input_to_cell_weights       = input_to_cell_weights;
    _input_to_output_weights     = input_to_output_weights;
    _recurrent_to_forget_weights = recurrent_to_forget_weights;
    _recurrent_to_cell_weights   = recurrent_to_cell_weights;
    _recurrent_to_output_weights = recurrent_to_output_weights;
    _projection_weights          = lstm_params.projection_weights();

    // Layer normalization
    _has_layer_norm = lstm_params.use_layer_norm();
    if(_has_layer_norm)
    {
        set_layer_norm_weight(lstm_params.forget_layer_norm_weights(), LayerNormGate::Forget);
        set_layer_norm_weight(lstm_params.cell_layer_norm_weights(), LayerNormGate::Cell);
        set_layer_norm_weight(lstm_params.input_layer_norm_weights(), LayerNormGate::Input);
        set_layer_norm_weight(lstm_params.output_layer_norm_weights(), LayerNormGate::Output);

        set_layer_norm_bias(forget_gate_bias, LayerNormGate::Forget);
        set_layer_norm_bias(cell_bias, LayerNormGate::Cell);
        set_layer_norm_bias(lstm_params.input_gate_bias(), LayerNormGate::Input);
        set_layer_norm_bias(output_gate_bias, LayerNormGate::Output);
    }

    _has_cifg       = lstm_params.has_cifg_opt();
    _has_projection = lstm_params.has_projection();
    _has_peephole   = lstm_params.has_peephole_opt();

    // Calculate and decompose effective scales for optimizing matmul calculation
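    // Note: the cell state scale is assumed to be a power of two (for example,
    // a scale of 2^-11 gives cell_shift = -11), so scaling by it can be applied
    // as a plain bit shift in the fixed-point arithmetic below.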
    const int32_t cell_shift = log2(qcell_state_in.scale);

    // Calculate quantized parameters for clipping.
    int16_t quantized_cell_clip = 0;
    if(lstm_params.cell_clip() > 0.0f)
    {
        quantized_cell_clip = quantize_qsymm16(lstm_params.cell_clip(), qcell_state_in);
    }
    _has_cell_clipping = quantized_cell_clip > 0;

    // Precompute effective bias for optimizing the matmul computations.
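    // With an asymmetric quantized input x_q (offset z), the real-valued
    // product expands as (x_q - z) * W = x_q * W - z * rowsum(W). The reduction
    // kernels below compute the constant -z * rowsum(W) term once per weight
    // matrix and feed it to the GEMM output stage as an effective bias.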
    if(!_has_cifg)
    {
        _input_to_input_weights     = lstm_params.input_to_input_weights();
        _recurrent_to_input_weights = lstm_params.recurrent_to_input_weights();

        _input_to_input_reduction     = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
        _recurrent_to_input_reduction = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
        _input_to_input_reduction->configure(_input_to_input_weights->info(), _input_to_input_eff_bias.info(), GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true));
        _recurrent_to_input_reduction->configure(_recurrent_to_input_weights->info(), _recurrent_to_input_eff_bias.info(), GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true));
    }

    _input_to_forget_reduction     = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
    _recurrent_to_forget_reduction = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
    _input_to_cell_reduction       = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
    _recurrent_to_cell_reduction   = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
    _input_to_output_reduction     = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
    _recurrent_to_output_reduction = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();

    _input_to_forget_reduction->configure(input_to_forget_weights->info(), _input_to_forget_eff_bias.info(), GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true));
    _recurrent_to_forget_reduction->configure(recurrent_to_forget_weights->info(), _recurrent_to_forget_eff_bias.info(), GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true));
    _input_to_cell_reduction->configure(input_to_cell_weights->info(), _input_to_cell_eff_bias.info(), GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true));
    _recurrent_to_cell_reduction->configure(recurrent_to_cell_weights->info(), _recurrent_to_cell_eff_bias.info(), GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true));
    _input_to_output_reduction->configure(input_to_output_weights->info(), _input_to_output_eff_bias.info(), GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true));
    _recurrent_to_output_reduction->configure(recurrent_to_output_weights->info(), _recurrent_to_output_eff_bias.info(), GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true));
    if(_has_projection)
    {
        _projection_reduction = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
        _projection_reduction->configure(_projection_weights->info(), _projection_eff_bias.info(), GEMMLowpReductionKernelInfo(output_size, false, lstm_params.hidden_state_zero(), true));
        if(_projection_bias != nullptr)
        {
            _projection_bias_add.configure(_projection_bias, &_projection_eff_bias, &_projection_eff_bias, ConvertPolicy::SATURATE);
        }
    }

    // Pre-transpose weights to be used in GEMM.
    _transpose_input_to_forget_weights.configure(input_to_forget_weights, &_input_to_forget_weights_transposed);
    _transpose_input_to_cell_weights.configure(input_to_cell_weights, &_input_to_cell_weights_transposed);
    _transpose_input_to_output_weights.configure(input_to_output_weights, &_input_to_output_weights_transposed);
    _transpose_recurrent_to_forget_weights.configure(recurrent_to_forget_weights, &_recurrent_to_forget_weights_transposed);
    _transpose_recurrent_to_cell_weights.configure(recurrent_to_cell_weights, &_recurrent_to_cell_weights_transposed);
    _transpose_recurrent_to_output_weights.configure(recurrent_to_output_weights, &_recurrent_to_output_weights_transposed);
    if(!_has_cifg)
    {
        _transpose_input_to_input_weights.configure(lstm_params.input_to_input_weights(), &_input_to_input_weights_transposed);
        _transpose_recurrent_to_input_weights.configure(lstm_params.recurrent_to_input_weights(), &_recurrent_to_input_weights_transposed);
    }
    if(_has_projection)
    {
        _transpose_projection_weights.configure(_projection_weights, &_projection_weights_transposed);
    }

    GEMMLowpOutputStageInfo gemmlowp_info;
    gemmlowp_info.type               = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
    gemmlowp_info.gemmlowp_min_bound = std::numeric_limits<int16_t>::lowest();
    gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int16_t>::max();
    gemmlowp_info.output_data_type   = DataType::QSYMM16;

    const TensorInfo mm_out_info(TensorShape(num_units, batch_size), 1, DataType::S32);
    // Forget gate.
    const TensorInfo forget_gate_outstage_info(mm_out_info.tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.forget_intermediate_scale(), 0));
    const float      input_to_forget_scale = input_to_forget_weights->info()->quantization_info().uniform().scale * qinput.scale / lstm_params.forget_intermediate_scale();
    configure_mm(_mm_input_to_forget, _input_to_forget_outstage, gemmlowp_info,
                 input, &_input_to_forget_weights_transposed, &_input_to_forget_eff_bias,
                 &_mm_input_to_forget_res, &_input_to_forget_outstage_res, input_to_forget_scale,
                 mm_out_info, forget_gate_outstage_info);

    const float recurrent_to_forget_scale = recurrent_to_forget_weights->info()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.forget_intermediate_scale();
    configure_mm(_mm_recurrent_to_forget, _recurrent_to_forget_outstage, gemmlowp_info,
                 output_state_in, &_recurrent_to_forget_weights_transposed, &_recurrent_to_forget_eff_bias,
                 &_mm_recurrent_to_forget_res, &_recurrent_to_forget_outstage_res, recurrent_to_forget_scale,
                 mm_out_info, forget_gate_outstage_info);

    _accumulate_input_recurrent_forget.configure(&_input_to_forget_outstage_res, &_recurrent_to_forget_outstage_res, &_recurrent_to_forget_outstage_res, ConvertPolicy::SATURATE);
    _input_to_forget_outstage_res.allocator()->allocate();

    if(_has_peephole)
    {
        _mul_cell_to_forget_res.allocator()->init(TensorInfo(cell_state_in->info()->tensor_shape(), 1, DataType::S32));
        _memory_group.manage(&_mul_cell_to_forget_res);
        _pixelwise_mul_cell_to_forget.configure(cell_state_in, lstm_params.cell_to_forget_weights(), &_mul_cell_to_forget_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
        _cell_to_forget_outstage_res.allocator()->init(TensorInfo(_mul_cell_to_forget_res.info()->tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.forget_intermediate_scale(), 0)));
        _memory_group.manage(&_cell_to_forget_outstage_res);
        const float cell_to_forget_scale = std::pow(2, cell_shift) * lstm_params.cell_to_forget_weights()->info()->quantization_info().uniform().scale / lstm_params.forget_intermediate_scale();
        quantization::calculate_quantized_multiplier(cell_to_forget_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift);
        _cell_to_forget_outstage.configure(&_mul_cell_to_forget_res, nullptr, &_cell_to_forget_outstage_res, gemmlowp_info);
        _mul_cell_to_forget_res.allocator()->allocate();
        _accumulate_cell_forget.configure(&_recurrent_to_forget_outstage_res, &_cell_to_forget_outstage_res, &_recurrent_to_forget_outstage_res, ConvertPolicy::SATURATE);
        _cell_to_forget_outstage_res.allocator()->allocate();
    }

    Tensor *forget_activation_input = &_recurrent_to_forget_outstage_res;

    if(_has_layer_norm)
    {
        configure_layer_norm(LayerNormGate::Forget, forget_activation_input);
        forget_activation_input->allocator()->allocate();
        forget_activation_input = &get_layer_norm_output(LayerNormGate::Forget);
    }

    // Output quantization info of Sigmoid and Tanh activations
    const QuantizationInfo sigmoid_tanh_outqinfo(1.f / 32768.f, 0);
    const TensorInfo       forget_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);

    _memory_group.manage(&_forget_gate);
    _forget_gate.allocator()->init(forget_gate_info);
    _forget_gate_sigmoid.configure(forget_activation_input, &_forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
    forget_activation_input->allocator()->allocate();

    // Modulation gate.
    const TensorInfo cell_outstage_info(mm_out_info.tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.cell_intermediate_scale(), 0));
    const float      input_to_cell_scale = input_to_cell_weights->info()->quantization_info().uniform().scale * qinput.scale / lstm_params.cell_intermediate_scale();
    configure_mm(_mm_input_to_cell, _input_to_cell_outstage, gemmlowp_info,
                 input, &_input_to_cell_weights_transposed, &_input_to_cell_eff_bias,
                 &_mm_input_to_cell_res, &_input_to_cell_outstage_res, input_to_cell_scale,
                 mm_out_info, cell_outstage_info);

    const float recurrent_to_cell_scale = recurrent_to_cell_weights->info()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.cell_intermediate_scale();
    configure_mm(_mm_recurrent_to_cell, _recurrent_to_cell_outstage, gemmlowp_info,
                 output_state_in, &_recurrent_to_cell_weights_transposed, &_recurrent_to_cell_eff_bias,
                 &_mm_recurrent_to_cell_res, &_recurrent_to_cell_outstage_res, recurrent_to_cell_scale,
                 mm_out_info, cell_outstage_info);

    _accumulate_input_recurrent_modulation.configure(&_input_to_cell_outstage_res, &_recurrent_to_cell_outstage_res, &_recurrent_to_cell_outstage_res, ConvertPolicy::SATURATE);
    _input_to_cell_outstage_res.allocator()->allocate();

    Tensor *cell_activation_input = &_recurrent_to_cell_outstage_res;

    if(_has_layer_norm)
    {
        configure_layer_norm(LayerNormGate::Cell, cell_activation_input);
        cell_activation_input->allocator()->allocate();
        cell_activation_input = &get_layer_norm_output(LayerNormGate::Cell);
    }

    const TensorInfo cell_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);

    _memory_group.manage(&_cell_gate);
    _cell_gate.allocator()->init(cell_gate_info);
    _cell_gate_tanh.configure(cell_activation_input, &_cell_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
    cell_activation_input->allocator()->allocate();

    // Input gate.
    const TensorInfo input_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
    _input_gate.allocator()->init(input_gate_info);
    _memory_group.manage(&_input_gate);
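    // With CIFG (coupled input-forget gate), the input gate is not computed
    // from weights; it is tied to the forget gate as i = 1 - f. _ones is
    // expected to hold the quantized representation of 1 (32767 in Q0.15),
    // filled elsewhere in this function's run path, so the subtraction below
    // yields the coupled input gate directly.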
    if(_has_cifg)
    {
        _ones.allocator()->init(*_forget_gate.info());
        _input_gate_sub.configure(&_ones, &_forget_gate, &_input_gate, ConvertPolicy::SATURATE);
        _ones.allocator()->allocate();
    }
    else
    {
        const TensorInfo input_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.input_intermediate_scale(), 0));
        const float      input_to_input_scale = _input_to_input_weights->info()->quantization_info().uniform().scale * qinput.scale / lstm_params.input_intermediate_scale();
        configure_mm(_mm_input_to_input, _input_to_input_outstage, gemmlowp_info,
                     input, &_input_to_input_weights_transposed, &_input_to_input_eff_bias,
                     &_mm_input_to_input_res, &_input_to_input_outstage_res, input_to_input_scale,
                     mm_out_info, input_outstage_info);

        const float recurrent_to_input_scale = _recurrent_to_input_weights->info()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.input_intermediate_scale();
        configure_mm(_mm_recurrent_to_input, _recurrent_to_input_outstage, gemmlowp_info,
                     output_state_in, &_recurrent_to_input_weights_transposed, &_recurrent_to_input_eff_bias,
                     &_mm_recurrent_to_input_res, &_recurrent_to_input_outstage_res, recurrent_to_input_scale,
                     mm_out_info, input_outstage_info);
        _accumulate_input_recurrent_input.configure(&_input_to_input_outstage_res, &_recurrent_to_input_outstage_res, &_recurrent_to_input_outstage_res, ConvertPolicy::SATURATE);
        _input_to_input_outstage_res.allocator()->allocate();

        if(_has_peephole)
        {
            _mul_cell_to_input_res.allocator()->init(TensorInfo(cell_state_in->info()->tensor_shape(), 1, DataType::S32));
            _memory_group.manage(&_mul_cell_to_input_res);
            _pixelwise_mul_cell_to_input.configure(cell_state_in, lstm_params.cell_to_input_weights(), &_mul_cell_to_input_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
            const float cell_to_input_scale = std::pow(2, cell_shift) * lstm_params.cell_to_input_weights()->info()->quantization_info().uniform().scale / lstm_params.input_intermediate_scale();
            quantization::calculate_quantized_multiplier(cell_to_input_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift);
            _cell_to_input_outstage_res.allocator()->init(TensorInfo(_mul_cell_to_input_res.info()->tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.input_intermediate_scale(), 0)));
            _memory_group.manage(&_cell_to_input_outstage_res);
            _cell_to_input_outstage.configure(&_mul_cell_to_input_res, nullptr, &_cell_to_input_outstage_res, gemmlowp_info);
            _mul_cell_to_input_res.allocator()->allocate();
            _accumulate_cell_input.configure(&_recurrent_to_input_outstage_res, &_cell_to_input_outstage_res, &_recurrent_to_input_outstage_res, ConvertPolicy::SATURATE);
            _cell_to_input_outstage_res.allocator()->allocate();
        }

        Tensor *input_activation_input = &_recurrent_to_input_outstage_res;

        if(_has_layer_norm)
        {
            configure_layer_norm(LayerNormGate::Input, input_activation_input);
            input_activation_input->allocator()->allocate();
            input_activation_input = &get_layer_norm_output(LayerNormGate::Input);
        }

        _input_gate_sigmoid.configure(input_activation_input, &_input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
        input_activation_input->allocator()->allocate();
    }
    // Cell.
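    // Cell state update: cell_state_out = forget_gate * cell_state_in +
    // input_gate * cell_gate, computed as two pixel-wise products followed by
    // a saturating addition below.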
    // TODO(COMPMID-3395): Perform multiplication in the quantized domain in NEPixelWiseMultiplication
    _pixelwise_mul_forget_cell.configure(&_forget_gate, cell_state_in, &_forget_gate, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
    const float      cell_gate_scale      = _cell_gate.info()->quantization_info().uniform().scale;
    const float      mul_input_cell_scale = cell_gate_scale * std::pow(2, 15 + cell_shift);
    const TensorInfo mul_input_cell_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(mul_input_cell_scale, 0));
    _memory_group.manage(&_mul_input_cell_res);
    _mul_input_cell_res.allocator()->init(mul_input_cell_info);
    _pixelwise_mul_input_cell.configure(&_input_gate, &_cell_gate, &_mul_input_cell_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
    _cell_gate.allocator()->allocate();
    _add_forget_cell.configure(&_forget_gate, &_mul_input_cell_res, cell_state_out, ConvertPolicy::SATURATE);
    _mul_input_cell_res.allocator()->allocate();
    _forget_gate.allocator()->allocate();
    if(_has_cell_clipping)
    {
        _cell_clip.configure(cell_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -quantized_cell_clip, quantized_cell_clip));
    }
    // Output gate.
    const TensorInfo output_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.output_intermediate_scale(), 0));
    const float      input_to_output_scale = input_to_output_weights->info()->quantization_info().uniform().scale * qinput.scale / lstm_params.output_intermediate_scale();
    configure_mm(_mm_input_to_output, _input_to_output_outstage, gemmlowp_info,
                 input, &_input_to_output_weights_transposed, &_input_to_output_eff_bias,
                 &_mm_input_to_output_res, &_input_to_output_outstage_res, input_to_output_scale,
                 mm_out_info, output_outstage_info);

    const float recurrent_to_output_scale = recurrent_to_output_weights->info()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.output_intermediate_scale();
    configure_mm(_mm_recurrent_to_output, _recurrent_to_output_outstage, gemmlowp_info,
                 output_state_in, &_recurrent_to_output_weights_transposed, &_recurrent_to_output_eff_bias,
                 &_mm_recurrent_to_output_res, &_recurrent_to_output_outstage_res, recurrent_to_output_scale,
                 mm_out_info, output_outstage_info);

    _accumulate_input_recurrent_output.configure(&_recurrent_to_output_outstage_res, &_input_to_output_outstage_res, &_recurrent_to_output_outstage_res, ConvertPolicy::SATURATE);
    _input_to_output_outstage_res.allocator()->allocate();

    if(_has_peephole)
    {
        // TODO(COMPMID-3395): Perform multiplication in the quantized domain in NEPixelWiseMultiplication
        // The pixel-wise product accumulates in S32 and is requantized to the
        // output intermediate scale by the output stage configured below.
        _mul_cell_to_output_res.allocator()->init(TensorInfo(cell_state_out->info()->tensor_shape(), 1, DataType::S32));
        _memory_group.manage(&_mul_cell_to_output_res);
        _pixelwise_mul_cell_to_output.configure(cell_state_out, lstm_params.cell_to_output_weights(), &_mul_cell_to_output_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);

        const float cell_to_output_scale = std::pow(2, cell_shift) * lstm_params.cell_to_output_weights()->info()->quantization_info().uniform().scale / lstm_params.output_intermediate_scale();
        quantization::calculate_quantized_multiplier(cell_to_output_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift);
        _cell_to_output_outstage_res.allocator()->init(TensorInfo(_mul_cell_to_output_res.info()->tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.output_intermediate_scale(), 0)));
        _memory_group.manage(&_cell_to_output_outstage_res);
        _cell_to_output_outstage.configure(&_mul_cell_to_output_res, nullptr, &_cell_to_output_outstage_res, gemmlowp_info);
        _mul_cell_to_output_res.allocator()->allocate();

        _accumulate_cell_to_output.configure(&_recurrent_to_output_outstage_res, &_cell_to_output_outstage_res, &_recurrent_to_output_outstage_res, ConvertPolicy::SATURATE);
        _cell_to_output_outstage_res.allocator()->allocate();
    }

    Tensor *output_activation_input = &_recurrent_to_output_outstage_res;

    if(_has_layer_norm)
    {
        configure_layer_norm(LayerNormGate::Output, output_activation_input);
        output_activation_input->allocator()->allocate();
        output_activation_input = &get_layer_norm_output(LayerNormGate::Output);
    }
    const TensorInfo output_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);

    _memory_group.manage(&_output_gate);
    _output_gate.allocator()->init(output_gate_info);
    _output_gate_sigmoid.configure(output_activation_input, &_output_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
    output_activation_input->allocator()->allocate();

    // Hidden.
    _hidden_tanh.configure(cell_state_out, &_input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
    // TODO(COMPMID-3395): Perform multiplication in the quantized domain in NEPixelWiseMultiplication
    _memory_group.manage(&_hidden_mul_res);
    const TensorInfo hidden_mul_res(_input_gate.info()->tensor_shape(), 1, DataType::S32);
    _hidden_mul_res.allocator()->init(hidden_mul_res);
    _pixelwise_mul_hidden.configure(&_output_gate, &_input_gate, &_hidden_mul_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
    _output_gate.allocator()->allocate();
    _input_gate.allocator()->allocate();
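    // Both tanh(cell_state) and sigmoid(output_gate) are QSYMM16 with scale
    // 2^-15, so their S32 product carries scale 2^-30; the requantization
    // factor into the hidden state is therefore 2^-15 * 2^-15 / hidden_state_scale.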
    const float hidden_state_scale = std::pow(2, -15) / lstm_params.hidden_state_scale() * std::pow(2, -15);
    quantization::calculate_quantized_multiplier(hidden_state_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift, /* ignore_epsilon */ true);
    gemmlowp_info.gemmlowp_offset  = lstm_params.hidden_state_zero();
    gemmlowp_info.output_data_type = output_state_in->info()->data_type();

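    // When the projection changes the width (num_units != output_size), the
    // hidden gate cannot be written straight into output_state_out; an
    // intermediate tensor and explicit row copies are used instead.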
    _projection_tensor_copy_required = (num_units != output_size);
    ITensor *hidden_gate_result      = output_state_out;

    _memory_group.manage(&_hidden_gate);

    if(_projection_tensor_copy_required)
    {
        _hidden_gate.allocator()->init(*output_state_out->info());
        _hidden_gate.info()->set_tensor_shape(_hidden_mul_res.info()->tensor_shape());
        hidden_gate_result = &_hidden_gate;
    }

    _hidden_outstage.configure(&_hidden_mul_res, nullptr, hidden_gate_result, gemmlowp_info);
    _hidden_mul_res.allocator()->allocate();

    // Projection.
    if(_has_projection)
    {
        const TensorInfo              projection_outstage_info(*output_state_out->info());
        const UniformQuantizationInfo qprojection      = _projection_weights->info()->quantization_info().uniform();
        const float                   projection_scale = qprojection.scale * lstm_params.hidden_state_scale() / qoutput_state_in.scale;
        gemmlowp_info.gemmlowp_offset                  = qoutput_state_in.offset;
        gemmlowp_info.gemmlowp_min_bound               = std::numeric_limits<int8_t>::lowest();
        gemmlowp_info.gemmlowp_max_bound               = std::numeric_limits<int8_t>::max();
        gemmlowp_info.output_data_type                 = DataType::QASYMM8_SIGNED;

        TensorInfo projection_mm_out_info{ mm_out_info };
        projection_mm_out_info.set_tensor_shape(TensorShape(output_size, batch_size));

        configure_mm(_mm_projection, _projection_outstage, gemmlowp_info,
                     hidden_gate_result, &_projection_weights_transposed, &_projection_eff_bias,
                     &_mm_projection_res, &_projection_outstage_res, projection_scale,
                     projection_mm_out_info, projection_outstage_info);

        ITensor *accumulate_destination = output_state_out;

        if(_projection_tensor_copy_required)
        {
            _hidden_gate.allocator()->allocate();
            _projection_accumulate_res.allocator()->init(*output_state_in->info());
            _projection_accumulate_res.info()->set_tensor_shape(_projection_outstage_res.info()->tensor_shape());
            _projection_output_to_accumulate_copy.configure(*output_state_in, _projection_accumulate_res);
            accumulate_destination = &_projection_accumulate_res;
        }

        _accumulate_projection.configure(&_projection_outstage_res, accumulate_destination, accumulate_destination, ConvertPolicy::SATURATE);
        _projection_outstage_res.allocator()->allocate();

        if(_projection_tensor_copy_required)
        {
            _projection_accumulate_to_output_copy.configure(_projection_accumulate_res, *output_state_out);
            _projection_accumulate_res.allocator()->allocate();
        }

        int8_t quantized_projection_clip{ 0 };
        if(lstm_params.projection_clip() > 0.0f)
        {
            quantized_projection_clip = utility::clamp<int8_t>(lstm_params.projection_clip() / qprojection.scale, -128, 127);
        }

        if(quantized_projection_clip > 0)
        {
            _projection_clip.configure(output_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -quantized_projection_clip, quantized_projection_clip));
            _has_projection_clipping = true;
        }
    }
    else
    {
        if(_projection_tensor_copy_required)
        {
            _hidden_to_output_copy.configure(_hidden_gate, *output_state_out);
            _hidden_gate.allocator()->allocate();
        }
    }

    // Copy output_state_out to output
    _copy_output.configure(output_state_out, output);
}

Status NEQLSTMLayer::validate(const ITensorInfo *input,
                              const ITensorInfo *input_to_forget_weights, const ITensorInfo *input_to_cell_weights, const ITensorInfo *input_to_output_weights,
                              const ITensorInfo *recurrent_to_forget_weights, const ITensorInfo *recurrent_to_cell_weights, const ITensorInfo *recurrent_to_output_weights,
                              const ITensorInfo *forget_gate_bias, const ITensorInfo *cell_bias, const ITensorInfo *output_gate_bias,
                              const ITensorInfo *cell_state_in, const ITensorInfo *output_state_in,
                              const ITensorInfo *cell_state_out, const ITensorInfo *output_state_out, const ITensorInfo *output,
                              const LSTMParams<ITensorInfo> &lstm_params)
{
    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights, recurrent_to_forget_weights, recurrent_to_cell_weights,
                                        recurrent_to_output_weights, forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in,
                                        cell_state_out, output_state_out, output);

    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8_SIGNED);
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->num_dimensions() != 2, "Input must have exactly 2 dimensions");

    const unsigned int input_size  = input->dimension(0);
    const unsigned int batch_size  = input->dimension(1);
    const unsigned int num_units   = input_to_output_weights->dimension(1);
    const unsigned int output_size = output_state_out->dimension(_out_state_output_size_dimension_idx);

    ARM_COMPUTE_RETURN_ERROR_ON(input_to_output_weights->num_dimensions() != 2);
    ARM_COMPUTE_RETURN_ERROR_ON(input_to_output_weights->dimension(0) != input_size);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input_to_output_weights, input_to_forget_weights, input_to_cell_weights);
    ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_output_weights->num_dimensions() != 2);
    ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_output_weights->dimension(1) != num_units);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(recurrent_to_output_weights, recurrent_to_forget_weights, recurrent_to_cell_weights);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input_to_forget_weights, 1, DataType::QASYMM8_SIGNED, DataType::QSYMM8);

    // If the input_to_forget_weights data type is DataType::QSYMM8 then it can never match the other weights as they are all DataType::QASYMM8_SIGNED
    if(input_to_forget_weights->data_type() == DataType::QSYMM8)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_to_cell_weights, input_to_output_weights,
                                                           recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights);
    }
    else
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
                                                           recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights);
    }
    ARM_COMPUTE_RETURN_ERROR_ON(forget_gate_bias->num_dimensions() != 1);
    ARM_COMPUTE_RETURN_ERROR_ON(forget_gate_bias->dimension(0) != num_units);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(forget_gate_bias, cell_bias, output_gate_bias);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(forget_gate_bias, 1, DataType::S32);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(forget_gate_bias, cell_bias, output_gate_bias);

    ARM_COMPUTE_RETURN_ERROR_ON(cell_state_in->num_dimensions() != 2);
    ARM_COMPUTE_RETURN_ERROR_ON(cell_state_in->dimension(0) != num_units);
    ARM_COMPUTE_RETURN_ERROR_ON(cell_state_in->dimension(1) != batch_size);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(cell_state_in, 1, DataType::QSYMM16);

    ARM_COMPUTE_RETURN_ERROR_ON(output_state_in->num_dimensions() != 2);
    ARM_COMPUTE_RETURN_ERROR_ON(output_state_in->dimension(0) != output_size);
    ARM_COMPUTE_RETURN_ERROR_ON(output_state_in->dimension(1) != batch_size);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output_state_in);

    // Check whether peephole weights are all there or none
    if(lstm_params.has_peephole_opt())
    {
        ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_forget_weights(), lstm_params.cell_to_output_weights());
        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lstm_params.cell_to_forget_weights(), 1, DataType::QSYMM16);
        ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_forget_weights()->num_dimensions() != 1);
        ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_forget_weights()->dimension(0) != num_units);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lstm_params.cell_to_forget_weights(), lstm_params.cell_to_output_weights());
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(lstm_params.cell_to_forget_weights(), lstm_params.cell_to_output_weights());

        if(!lstm_params.has_cifg_opt())
        {
            ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_input_weights());
            ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lstm_params.cell_to_forget_weights(), lstm_params.cell_to_input_weights());
            ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(lstm_params.cell_to_forget_weights(), lstm_params.cell_to_input_weights());
        }
    }

    const UniformQuantizationInfo qinput           = input->quantization_info().uniform();
    const UniformQuantizationInfo qcell_state_in   = cell_state_in->quantization_info().uniform();
    const UniformQuantizationInfo qoutput_state_in = output_state_in->quantization_info().uniform();

    // Calculate and decompose effective scales for optimizing matmul calculation
    const int32_t cell_shift = log2(qcell_state_in.scale);
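    // The cell state scale is assumed to be a power of two, so cell_shift is
    // its exact log2; the check below rejects scales larger than 2^-9.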
    ARM_COMPUTE_RETURN_ERROR_ON(cell_shift > -9);

    // Calculate quantized parameters for clipping.
    int16_t quantized_cell_clip = 0;
    if(lstm_params.cell_clip() > 0.0f)
    {
        quantized_cell_clip = quantize_qsymm16(lstm_params.cell_clip(), qcell_state_in);
    }

    // Precompute effective bias for optimizing the matmul computations.
    const TensorInfo eff_bias_info(TensorShape(num_units), 1, DataType::S32);
    const TensorInfo projection_eff_bias_info(TensorShape(output_size), 1, DataType::S32);
    if(!lstm_params.has_cifg_opt())
    {
        ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmLowpMatrixAReductionKernel::validate(lstm_params.input_to_input_weights(), &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false,
                                                                                              -qinput.offset, true)));
        ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmLowpMatrixAReductionKernel::validate(lstm_params.recurrent_to_input_weights(), &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false,
                                                                                              -qoutput_state_in.offset, true)));
    }
    ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmLowpMatrixAReductionKernel::validate(input_to_forget_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true)));
    ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmLowpMatrixAReductionKernel::validate(recurrent_to_forget_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false,
                                                                                          -qoutput_state_in.offset, true)));
    ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmLowpMatrixAReductionKernel::validate(input_to_cell_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true)));
    ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmLowpMatrixAReductionKernel::validate(recurrent_to_cell_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset,
                                                                                          true)));
    ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmLowpMatrixAReductionKernel::validate(input_to_output_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true)));
    ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmLowpMatrixAReductionKernel::validate(recurrent_to_output_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false,
                                                                                          -qoutput_state_in.offset, true)));
    if(lstm_params.has_projection())
    {
        ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmLowpMatrixAReductionKernel::validate(lstm_params.projection_weights(), &projection_eff_bias_info, GEMMLowpReductionKernelInfo(output_size, false,
                                                                                              lstm_params.hidden_state_zero(), true)));
        if(lstm_params.projection_bias() != nullptr)
        {
            ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lstm_params.projection_bias(), 1, DataType::S32);
            ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(lstm_params.projection_bias(), &projection_eff_bias_info, &projection_eff_bias_info, ConvertPolicy::SATURATE));
        }
    }

    const TensorInfo input_weights_transposed(TensorShape(num_units, input_size), 1, input_to_cell_weights->data_type(), input_to_cell_weights->quantization_info());
    const TensorInfo input_to_output_weights_transposed(TensorShape(num_units, input_size), 1, input_to_output_weights->data_type(), input_to_output_weights->quantization_info());
    const TensorInfo recurrent_to_forget_weights_transposed(TensorShape(num_units, output_size), 1, recurrent_to_forget_weights->data_type(), recurrent_to_forget_weights->quantization_info());
    const TensorInfo recurrent_to_cell_weights_transposed(TensorShape(num_units, output_size), 1, recurrent_to_cell_weights->data_type(), recurrent_to_cell_weights->quantization_info());
    const TensorInfo recurrent_to_output_weights_transposed(TensorShape(num_units, output_size), 1, recurrent_to_output_weights->data_type(), recurrent_to_output_weights->quantization_info());
    const TensorInfo recurrent_weights_transposed(TensorShape(num_units, output_size), 1, recurrent_to_forget_weights->data_type(), recurrent_to_forget_weights->quantization_info());

    ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(input_to_cell_weights, &input_weights_transposed));
    ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(input_to_output_weights, &input_to_output_weights_transposed));
    ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(recurrent_to_forget_weights, &recurrent_to_forget_weights_transposed));
    ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(recurrent_to_cell_weights, &recurrent_to_cell_weights_transposed));
    ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(recurrent_to_output_weights, &recurrent_to_output_weights_transposed));
    if(!lstm_params.has_cifg_opt())
    {
        const TensorInfo recurrent_to_input_weights_transposed(TensorShape(num_units, output_size), 1,
                                                               recurrent_to_forget_weights->data_type(), lstm_params.recurrent_to_input_weights()->quantization_info());
        const TensorInfo input_to_input_weights_transposed(TensorShape(num_units, input_size), 1,
                                                           lstm_params.input_to_input_weights()->data_type(), lstm_params.input_to_input_weights()->quantization_info());
        ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(lstm_params.input_to_input_weights(), &input_to_input_weights_transposed));
        ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(lstm_params.recurrent_to_input_weights(), &recurrent_to_input_weights_transposed));
    }
    if(lstm_params.has_projection())
    {
        const TensorInfo projection_weights_transposed(TensorShape(output_size, num_units), 1, lstm_params.projection_weights()->data_type(), lstm_params.projection_weights()->quantization_info());
        ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(lstm_params.projection_weights(), &projection_weights_transposed));
    }
821 
822     GEMMLowpOutputStageInfo gemmlowp_info;
823     gemmlowp_info.type               = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
824     gemmlowp_info.gemmlowp_min_bound = std::numeric_limits<int16_t>::lowest();
825     gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int16_t>::max();
826     gemmlowp_info.output_data_type   = DataType::QSYMM16;
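    // QUANTIZE_DOWN_FIXEDPOINT requantizes the S32 GEMM accumulators with a fixed-point
    // multiplier/shift pair; the bounds clamp results to the int16 range, since every gate
    // intermediate is QSYMM16.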
827 
828     const bool has_layer_norm = lstm_params.use_layer_norm();
829 
830     // Forget gate.
831     ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.forget_intermediate_scale() == 0);
832     const TensorInfo forget_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.forget_intermediate_scale(), 0));
833     const TensorInfo mm_out_info(TensorShape(num_units, batch_size), 1, DataType::S32);
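    // Requantization factor for each outstage: real_scale = (weight_scale * input_scale) / intermediate_scale,
    // converted to a fixed-point multiplier and shift inside validate_mm().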
834     const float      input_to_forget_scale = input_to_forget_weights->quantization_info().uniform().scale * qinput.scale / lstm_params.forget_intermediate_scale();
835     ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_forget_scale, &mm_out_info, &forget_outstage_info));
836 
837     const float recurrent_to_forget_scale = recurrent_to_forget_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.forget_intermediate_scale();
838     ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_forget_scale, &mm_out_info, &forget_outstage_info));
839 
840     ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&forget_outstage_info, &forget_outstage_info, &forget_outstage_info, ConvertPolicy::SATURATE));
841 
842     if(lstm_params.has_peephole_opt())
843     {
844         ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lstm_params.cell_to_forget_weights(), 1, DataType::QSYMM16);
845         ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(cell_state_in, lstm_params.cell_to_forget_weights(), &mm_out_info, 1.f, ConvertPolicy::SATURATE,
846                                                                         RoundingPolicy::TO_ZERO));
847         const float cell_to_forget_scale = std::pow(2, cell_shift) * lstm_params.cell_to_forget_weights()->quantization_info().uniform().scale / lstm_params.forget_intermediate_scale();
848         ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(cell_to_forget_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
849         ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpOutputStage::validate(&mm_out_info, nullptr, &forget_outstage_info, gemmlowp_info));
850         ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&forget_outstage_info, &forget_outstage_info, &forget_outstage_info, ConvertPolicy::SATURATE));
851     }
852 
853     if(has_layer_norm)
854     {
855         const ITensorInfo *w_info = lstm_params.forget_layer_norm_weights();
856         const ITensorInfo *b_info = forget_gate_bias;
857         ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(forget_outstage_info, *w_info, *b_info));
858     }
859 
860     // Output quantization info of Sigmoid and Tanh activations
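    // A scale of 1/32768 (2^-15) lets QSYMM16 represent the full [-1, 1) output range of both activations.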
861     const QuantizationInfo sigmoid_tanh_outqinfo(1.f / 32768.f, 0);
862     const TensorInfo       forget_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
863 
864     ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&forget_outstage_info, &forget_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
865 
866     // Modulation gate.
867     ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_intermediate_scale() == 0);
868     const TensorInfo cell_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.cell_intermediate_scale(), 0));
869     const float      input_to_cell_scale = input_to_cell_weights->quantization_info().uniform().scale * qinput.scale / lstm_params.cell_intermediate_scale();
870     ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_cell_scale, &mm_out_info, &cell_outstage_info));
871 
872     const float recurrent_to_cell_scale = recurrent_to_cell_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.cell_intermediate_scale();
873     ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_cell_scale, &mm_out_info, &cell_outstage_info));
874 
875     ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&cell_outstage_info, &cell_outstage_info, &cell_outstage_info, ConvertPolicy::SATURATE));
876 
877     if(has_layer_norm)
878     {
879         const ITensorInfo *w_info = lstm_params.cell_layer_norm_weights();
880         const ITensorInfo *b_info = cell_bias;
881         ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(cell_outstage_info, *w_info, *b_info));
882     }
883     const TensorInfo cell_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
884 
885     ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&cell_outstage_info, &cell_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f)));
886 
887     // Input gate.
888     const TensorInfo input_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
889     if(lstm_params.has_cifg_opt())
890     {
891         ARM_COMPUTE_RETURN_ERROR_ON_MSG(lstm_params.input_gate_bias() != nullptr, "Input gate bias must not be present when CIFG is used");
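        // With CIFG the input gate is derived rather than computed: input_gate = 1 - forget_gate,
        // where the constant 1 is supplied at runtime by the _ones tensor (32767 in QSYMM16, see prepare()).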
892         ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticSubtraction::validate(&input_gate_info, &forget_gate_info, &forget_gate_info, ConvertPolicy::SATURATE));
893     }
894     else
895     {
896         ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.input_to_input_weights(), lstm_params.recurrent_to_input_weights(), lstm_params.input_gate_bias());
897 
898         // If input_to_forget_weights is DataType::QSYMM8, it can never match the input-gate weights (they are all DataType::QASYMM8_SIGNED), so only check the input-gate weights against each other
899         if (input_to_forget_weights->data_type() == DataType::QSYMM8)
900         {
901             ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lstm_params.input_to_input_weights(), lstm_params.recurrent_to_input_weights());
902         }
903         else
904         {
905             ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_to_forget_weights, lstm_params.input_to_input_weights(), lstm_params.recurrent_to_input_weights());
906         }
907         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input_to_forget_weights, lstm_params.input_to_input_weights());
908         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(recurrent_to_forget_weights, lstm_params.recurrent_to_input_weights());
909         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(forget_gate_bias, lstm_params.input_gate_bias());
910         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(forget_gate_bias, lstm_params.input_gate_bias());
911 
912         ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_intermediate_scale() == 0);
913         const TensorInfo input_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.input_intermediate_scale(), 0));
914         const float      input_to_input_scale = lstm_params.input_to_input_weights()->quantization_info().uniform().scale * qinput.scale / lstm_params.input_intermediate_scale();
915         ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_input_scale, &mm_out_info, &input_outstage_info));
916 
917         const float recurrent_to_input_scale = lstm_params.recurrent_to_input_weights()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.input_intermediate_scale();
918         ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_input_scale, &mm_out_info, &input_outstage_info));
919 
920         ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&input_outstage_info, &input_outstage_info, &input_outstage_info, ConvertPolicy::SATURATE));
921 
922         if(lstm_params.has_peephole_opt())
923         {
924             ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(cell_state_in, lstm_params.cell_to_input_weights(), &mm_out_info, 1.f, ConvertPolicy::SATURATE,
925                                                                             RoundingPolicy::TO_ZERO));
926             const float cell_to_input_scale = std::pow(2, cell_shift) * lstm_params.cell_to_input_weights()->quantization_info().uniform().scale / lstm_params.input_intermediate_scale();
927             ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(cell_to_input_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
928             ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpOutputStage::validate(&mm_out_info, &eff_bias_info, &input_outstage_info, gemmlowp_info));
929             ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&input_outstage_info, &input_outstage_info, &input_outstage_info, ConvertPolicy::SATURATE));
930         }
931 
932         if(has_layer_norm)
933         {
934             const ITensorInfo *w_info = lstm_params.input_layer_norm_weights();
935             const ITensorInfo *b_info = lstm_params.input_gate_bias();
936             ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(input_outstage_info, *w_info, *b_info));
937         }
938 
939         ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&input_outstage_info, &input_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f)));
940     }
941     // Cell.
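    // Cell state update: cell_state_out = forget_gate * cell_state_in + input_gate * cell_gate,
    // evaluated with saturating QSYMM16 arithmetic; the optional LU_BOUNDED_RELU below acts as a
    // symmetric clamp to [-quantized_cell_clip, quantized_cell_clip].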
942     ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(&forget_gate_info, cell_state_in, &forget_gate_info, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
943     ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(&input_gate_info, cell_state_in, &cell_gate_info, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
944     ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&forget_gate_info, &cell_gate_info, cell_state_out, ConvertPolicy::SATURATE));
945     if(quantized_cell_clip > 0)
946     {
947         ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(cell_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -quantized_cell_clip,
948                                                                                                              quantized_cell_clip)));
949     }
950     // Output gate.
951     ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.output_intermediate_scale() == 0);
952     const TensorInfo output_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.output_intermediate_scale(), 0));
953     const float      input_to_output_scale = input_to_output_weights->quantization_info().uniform().scale * qinput.scale / lstm_params.output_intermediate_scale();
954     ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_output_scale, &mm_out_info, &output_outstage_info));
955 
956     const float recurrent_to_output_scale = recurrent_to_output_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.output_intermediate_scale();
957     ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_output_scale, &mm_out_info, &output_outstage_info));
958 
959     ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&output_outstage_info, &output_outstage_info, &output_outstage_info, ConvertPolicy::SATURATE));
960     if(lstm_params.has_peephole_opt())
961     {
962         ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lstm_params.cell_to_output_weights(), 1, DataType::QSYMM16);
963         // TODO(COMPMID-3395): Perform multiplication in the quantized domain in NEPixelWiseMultiplication
964         // Here we are not using the output stage because all operations are done in float
965         // const float cell_to_output_scale = std::pow(2, cell_shift) * lstm_params.cell_to_output_weights()->quantization_info().uniform().scale / lstm_params.output_intermediate_scale();
966         // ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(cell_to_output_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
967         ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(cell_state_out, lstm_params.cell_to_output_weights(), &output_outstage_info, 1.f, ConvertPolicy::SATURATE,
968                                                                         RoundingPolicy::TO_ZERO));
969         ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&output_outstage_info, &output_outstage_info, &output_outstage_info, ConvertPolicy::SATURATE));
970     }
971 
972     if(has_layer_norm)
973     {
974         const ITensorInfo *w_info = lstm_params.output_layer_norm_weights();
975         const ITensorInfo *b_info = output_gate_bias;
976         ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(output_outstage_info, *w_info, *b_info));
977     }
978 
979     const TensorInfo output_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
980     ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&output_outstage_info, &output_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
981 
982     // Hidden.
983     ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(cell_state_out, &input_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f)));
984     const TensorInfo hidden_mul_res(TensorShape(num_units, batch_size), 1, DataType::S32);
985     const TensorInfo hidden_out_info(TensorShape(num_units, batch_size), 1, DataType::QASYMM8_SIGNED);
986     ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(&output_gate_info, &input_gate_info, &hidden_mul_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
987 
988     ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.hidden_state_scale() == 0);
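    // The hidden state is the product of two 2^-15-scaled QSYMM16 gates, so the accumulator carries
    // a scale of 2^-30; dividing by hidden_state_scale requantizes it into the hidden state's
    // QASYMM8_SIGNED domain.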
989     const float hidden_state_scale = std::pow(2, -15) / lstm_params.hidden_state_scale() * std::pow(2, -15);
990     ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(hidden_state_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift, /* ignore_epsilon */ true));
991     gemmlowp_info.gemmlowp_offset  = lstm_params.hidden_state_zero();
992     gemmlowp_info.output_data_type = hidden_out_info.data_type();
993     ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpOutputStage::validate(&hidden_mul_res, nullptr, &hidden_out_info, gemmlowp_info));
994 
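    // When num_units != output_size, intermediate copies are needed to shuttle data between the
    // projection/accumulation buffers and the output state (see the TensorCopyKernel validations below).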
995     const bool projection_tensor_copy_required = num_units != output_size;
996 
997     // Projection.
998     if(lstm_params.has_projection())
999     {
1000         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(recurrent_to_forget_weights, lstm_params.projection_weights());
1001         ARM_COMPUTE_RETURN_ERROR_ON(qoutput_state_in.scale == 0);
1002 
1003         const UniformQuantizationInfo qprojection      = lstm_params.projection_weights()->quantization_info().uniform();
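        // projection_scale = weight_scale * hidden_state_scale / output_state_scale: requantizes
        // the projection GEMM result into the output state's quantized domain.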
1004         const float                   projection_scale = qprojection.scale * lstm_params.hidden_state_scale() / qoutput_state_in.scale;
1005         ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(projection_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
1006         gemmlowp_info.gemmlowp_offset    = qoutput_state_in.offset;
1007         gemmlowp_info.gemmlowp_min_bound = std::numeric_limits<int8_t>::lowest();
1008         gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int8_t>::max();
1009         gemmlowp_info.output_data_type   = DataType::QASYMM8_SIGNED;
1010 
1011         const TensorInfo projection_outstage_info(*output_state_out);
1012         const TensorInfo projection_weights_transposed(TensorShape(output_size, num_units), 1, lstm_params.projection_weights()->data_type(), lstm_params.projection_weights()->quantization_info());
1013 
1014         TensorInfo projection_mm_out_info{ mm_out_info };
1015         projection_mm_out_info.set_tensor_shape(TensorShape(output_size, batch_size));
1016 
1017         ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, &hidden_out_info, &projection_weights_transposed, &projection_eff_bias_info, projection_scale, &projection_mm_out_info,
1018                                                 &projection_outstage_info));
1019 
1020         if(projection_tensor_copy_required)
1021         {
1022             ARM_COMPUTE_RETURN_ON_ERROR(NEQLSTMLayer::TensorCopyKernel::validate(*output_state_in, projection_outstage_info));
1023         }
1024 
1025         ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(output_state_out, output_state_out, output_state_out, ConvertPolicy::SATURATE));
1026 
1027         if(projection_tensor_copy_required)
1028         {
1029             ARM_COMPUTE_RETURN_ON_ERROR(NEQLSTMLayer::TensorCopyKernel::validate(projection_outstage_info, *output_state_out));
1030         }
1031 
1032         int8_t quantized_projection_clip{ 0 };
1033         if(lstm_params.projection_clip() > 0.0f)
1034         {
1035             quantized_projection_clip = quantize_qasymm8_signed(lstm_params.projection_clip(), qprojection);
1036         }
1037 
1038         if(quantized_projection_clip > 0)
1039         {
1040             ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(output_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -quantized_projection_clip,
1041                                                                                                                    quantized_projection_clip)));
1042         }
1043     }
1044     else
1045     {
1046         if(projection_tensor_copy_required)
1047         {
1048             ARM_COMPUTE_RETURN_ON_ERROR(NEQLSTMLayer::TensorCopyKernel::validate(hidden_out_info, *output_state_out));
1049         }
1050     }
1051 
1052     if(cell_state_out->total_size() > 0)
1053     {
1054         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(cell_state_in, cell_state_out);
1055         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(cell_state_in, cell_state_out);
1056     }
1057 
1058     if(output_state_out->total_size() > 0)
1059     {
1060         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output_state_out);
1061         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output_state_in, output_state_out);
1062     }
1063 
1064     ARM_COMPUTE_RETURN_ON_ERROR(NECopy::validate(output_state_out, output));
1065     return Status{};
1066 }
1067 
1068 void NEQLSTMLayer::run()
1069 {
1070     prepare();
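    // prepare() only does work on the first call (weight transposition and effective-bias
    // precomputation); afterwards it is a cheap no-op guarded by _is_prepared.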
1071 
1072     // Acquire all the temporaries
1073     MemoryGroupResourceScope scope_mg(_memory_group);
1074 
1075     // Forget gate.
1076     _mm_input_to_forget.run();
1077     _input_to_forget_outstage.run();
1078 
1079     _mm_recurrent_to_forget.run();
1080     _recurrent_to_forget_outstage.run();
1081     _accumulate_input_recurrent_forget.run();
1082 
1083     if(_has_peephole)
1084     {
1085         _pixelwise_mul_cell_to_forget.run();
1086         _cell_to_forget_outstage.run();
1087         _accumulate_cell_forget.run();
1088     }
1089 
1090     if(_has_layer_norm)
1091     {
1092         NEScheduler::get().schedule(get_layer_norm(LayerNormGate::Forget).get(), Window::DimY);
1093     }
1094 
1095     _forget_gate_sigmoid.run();
1096 
1097     // Modulation gate.
1098     _mm_input_to_cell.run();
1099     _input_to_cell_outstage.run();
1100 
1101     _mm_recurrent_to_cell.run();
1102     _recurrent_to_cell_outstage.run();
1103     _accumulate_input_recurrent_modulation.run();
1104 
1105     if(_has_layer_norm)
1106     {
1107         NEScheduler::get().schedule(get_layer_norm(LayerNormGate::Cell).get(), Window::DimY);
1108     }
1109 
1110     _cell_gate_tanh.run();
1111 
1112     // Input gate
1113     if(_has_cifg)
1114     {
1115         _input_gate_sub.run();
1116     }
1117     else
1118     {
1119         _mm_input_to_input.run();
1120         _input_to_input_outstage.run();
1121         _mm_recurrent_to_input.run();
1122         _recurrent_to_input_outstage.run();
1123         _accumulate_input_recurrent_input.run();
1124 
1125         if(_has_peephole)
1126         {
1127             _pixelwise_mul_cell_to_input.run();
1128             _cell_to_input_outstage.run();
1129             _accumulate_cell_input.run();
1130         }
1131 
1132         if(_has_layer_norm)
1133         {
1134             NEScheduler::get().schedule(get_layer_norm(LayerNormGate::Input).get(), Window::DimY);
1135         }
1136 
1137         _input_gate_sigmoid.run();
1138     }
1139 
1140     // Cell.
1141     _pixelwise_mul_forget_cell.run();
1142     _pixelwise_mul_input_cell.run();
1143     _add_forget_cell.run();
1144 
1145     if(_has_cell_clipping)
1146     {
1147         _cell_clip.run();
1148     }
1149 
1150     // Output gate.
1151     _mm_input_to_output.run();
1152     _input_to_output_outstage.run();
1153     _mm_recurrent_to_output.run();
1154     _recurrent_to_output_outstage.run();
1155     _accumulate_input_recurrent_output.run();
1156     if(_has_peephole)
1157     {
1158         _pixelwise_mul_cell_to_output.run();
1159         _cell_to_output_outstage.run();
1160         _accumulate_cell_to_output.run();
1161     }
1162 
1163     if(_has_layer_norm)
1164     {
1165         NEScheduler::get().schedule(get_layer_norm(LayerNormGate::Output).get(), Window::DimY);
1166     }
1167 
1168     _output_gate_sigmoid.run();
1169 
1170     // Hidden.
1171     _hidden_tanh.run();
1172     _pixelwise_mul_hidden.run();
1173     _hidden_outstage.run();
1174 
1175     // Projection.
1176     if(_has_projection)
1177     {
1178         _mm_projection.run();
1179         _projection_outstage.run();
1180 
1181         if(_projection_tensor_copy_required)
1182         {
1183             _projection_output_to_accumulate_copy.run();
1184         }
1185 
1186         _accumulate_projection.run();
1187 
1188         if(_projection_tensor_copy_required)
1189         {
1190             _projection_accumulate_to_output_copy.run();
1191         }
1192 
1193         if(_has_projection_clipping)
1194         {
1195             _projection_clip.run();
1196         }
1197     }
1198     else
1199     {
1200         if(_projection_tensor_copy_required)
1201         {
1202             _hidden_to_output_copy.run();
1203         }
1204     }
1205 
1206     // Copy output_state_out to output
1207     _copy_output.run();
1208 }
1209 
1210 void NEQLSTMLayer::prepare()
1211 {
1212     if(!_is_prepared)
1213     {
1214         if(_convert_input_to_forget_weights_to_qsymm8)
1215         {
1216             _input_to_forget_weights_f32.allocator()->allocate();
1217             _input_to_forget_weights_symm8.allocator()->allocate();
1218             _dequantize_input_to_forget_weights.run();
1219             _quantize_input_to_forget_weights.run();
1220         }
1221 
1222         // Pre-transpose weights to be used in GEMM.
1223         _input_to_forget_weights_transposed.allocator()->allocate();
1224         _input_to_cell_weights_transposed.allocator()->allocate();
1225         _input_to_output_weights_transposed.allocator()->allocate();
1226         _recurrent_to_forget_weights_transposed.allocator()->allocate();
1227         _recurrent_to_cell_weights_transposed.allocator()->allocate();
1228         _recurrent_to_output_weights_transposed.allocator()->allocate();
1229         _transpose_input_to_forget_weights.run();
1230         _transpose_input_to_cell_weights.run();
1231         _transpose_input_to_output_weights.run();
1232         _transpose_recurrent_to_forget_weights.run();
1233         _transpose_recurrent_to_cell_weights.run();
1234         _transpose_recurrent_to_output_weights.run();
1235 
1236         // Precompute effective biases
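        // Each reduction computes -zero_point * sum(weight row) into the corresponding effective
        // bias tensor, mirroring the reduction validations performed in validate().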
1237         if(_has_cifg)
1238         {
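            // 32767 is the QSYMM16 representation of ~1.0 (scale 2^-15), used for input_gate = 1 - forget_gate.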
1239             std::fill_n(reinterpret_cast<int16_t *>(_ones.buffer()), _ones.info()->total_size() / _ones.info()->element_size(), 32767);
1240         }
1241         else
1242         {
1243             _input_to_input_eff_bias.allocator()->allocate();
1244             _recurrent_to_input_eff_bias.allocator()->allocate();
1245 
1246             ITensorPack packII =
1247             {
1248                 { TensorType::ACL_SRC, _input_to_input_weights },
1249                 { TensorType::ACL_DST, &_input_to_input_eff_bias }
1250             };
1251             NEScheduler::get().schedule_op(_input_to_input_reduction.get(), Window::DimY, _input_to_input_reduction->window(), packII);
1252 
1253             ITensorPack packRI =
1254             {
1255                 { TensorType::ACL_SRC, _recurrent_to_input_weights },
1256                 { TensorType::ACL_DST, &_recurrent_to_input_eff_bias }
1257             };
1258             NEScheduler::get().schedule_op(_recurrent_to_input_reduction.get(), Window::DimY, _recurrent_to_input_reduction->window(), packRI);
1259 
1260             _input_to_input_weights_transposed.allocator()->allocate();
1261             _recurrent_to_input_weights_transposed.allocator()->allocate();
1262             _transpose_input_to_input_weights.run();
1263             _transpose_recurrent_to_input_weights.run();
1264             _input_to_input_weights->mark_as_unused();
1265             _recurrent_to_input_weights->mark_as_unused();
1266         }
1267         _input_to_forget_eff_bias.allocator()->allocate();
1268         _recurrent_to_forget_eff_bias.allocator()->allocate();
1269         _input_to_cell_eff_bias.allocator()->allocate();
1270         _recurrent_to_cell_eff_bias.allocator()->allocate();
1271         _input_to_output_eff_bias.allocator()->allocate();
1272         _recurrent_to_output_eff_bias.allocator()->allocate();
1273 
1274         ITensorPack packIF =
1275         {
1276             { TensorType::ACL_SRC, _input_to_forget_weights },
1277             { TensorType::ACL_DST, &_input_to_forget_eff_bias }
1278         };
1279         NEScheduler::get().schedule_op(_input_to_forget_reduction.get(), Window::DimY, _input_to_forget_reduction->window(), packIF);
1280 
1281         ITensorPack packRF =
1282         {
1283             { TensorType::ACL_SRC, _recurrent_to_forget_weights },
1284             { TensorType::ACL_DST, &_recurrent_to_forget_eff_bias }
1285         };
1286         NEScheduler::get().schedule_op(_recurrent_to_forget_reduction.get(), Window::DimY, _recurrent_to_forget_reduction->window(), packRF);
1287 
1288         ITensorPack packIC =
1289         {
1290             { TensorType::ACL_SRC, _input_to_cell_weights },
1291             { TensorType::ACL_DST, &_input_to_cell_eff_bias }
1292         };
1293         NEScheduler::get().schedule_op(_input_to_cell_reduction.get(), Window::DimY, _input_to_cell_reduction->window(), packIC);
1294 
1295         ITensorPack packRC =
1296         {
1297             { TensorType::ACL_SRC, _recurrent_to_cell_weights },
1298             { TensorType::ACL_DST, &_recurrent_to_cell_eff_bias }
1299         };
1300         NEScheduler::get().schedule_op(_recurrent_to_cell_reduction.get(), Window::DimY, _recurrent_to_cell_reduction->window(), packRC);
1301 
1302         ITensorPack packIO =
1303         {
1304             { TensorType::ACL_SRC, _input_to_output_weights },
1305             { TensorType::ACL_DST, &_input_to_output_eff_bias }
1306         };
1307         NEScheduler::get().schedule_op(_input_to_output_reduction.get(), Window::DimY, _input_to_output_reduction->window(), packIO);
1308 
1309         ITensorPack packRO =
1310         {
1311             { TensorType::ACL_SRC, _recurrent_to_output_weights },
1312             { TensorType::ACL_DST, &_recurrent_to_output_eff_bias }
1313         };
1314         NEScheduler::get().schedule_op(_recurrent_to_output_reduction.get(), Window::DimY, _recurrent_to_output_reduction->window(), packRO);
1315 
1316         if(_has_projection)
1317         {
1318             _projection_eff_bias.allocator()->allocate();
1319             ITensorPack pack =
1320             {
1321                 { TensorType::ACL_SRC, _projection_weights },
1322                 { TensorType::ACL_DST, &_projection_eff_bias }
1323             };
1324             NEScheduler::get().schedule_op(_projection_reduction.get(), Window::DimY, _projection_reduction->window(), pack);
1325             if(_projection_bias != nullptr)
1326             {
1327                 _projection_bias_add.run();
1328                 _projection_bias->mark_as_unused();
1329             }
1330 
1331             _projection_weights_transposed.allocator()->allocate();
1332             _transpose_projection_weights.run();
1333             _projection_weights->mark_as_unused();
1334 
1335             if(!_projection_tensor_copy_required)
1336             {
1337                 _hidden_gate.mark_as_unused();
1338                 _projection_accumulate_res.mark_as_unused();
1339             }
1340         }
1341 
1342         // Mark weights as unused
1343         _input_to_forget_weights->mark_as_unused();
1344         _input_to_cell_weights->mark_as_unused();
1345         _input_to_output_weights->mark_as_unused();
1346         _recurrent_to_forget_weights->mark_as_unused();
1347         _recurrent_to_cell_weights->mark_as_unused();
1348         _recurrent_to_output_weights->mark_as_unused();
1349 
1350         _is_prepared = true;
1351     }
1352 }
1353 } // namespace arm_compute
1354