xref: /aosp_15_r20/external/ComputeLibrary/src/runtime/CL/functions/CLQLSTMLayer.cpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1 /*
2  * Copyright (c) 2020-2022 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "arm_compute/runtime/CL/functions/CLQLSTMLayer.h"
25 
26 #include "arm_compute/core/KernelDescriptors.h"
27 #include "arm_compute/core/QuantizationInfo.h"
28 #include "arm_compute/core/Utils.h"
29 #include "arm_compute/core/Validate.h"
30 #include "arm_compute/core/utils/misc/InfoHelpers.h"
31 #include "arm_compute/core/utils/quantization/AsymmHelpers.h"
32 #include "arm_compute/runtime/CL/CLScheduler.h"
33 #include "src/core/CL/kernels/CLFillBorderKernel.h"
34 #include "src/core/CL/kernels/CLQLSTMLayerNormalizationKernel.h"
35 #include "src/core/helpers/WindowHelpers.h"
36 #include "src/gpu/cl/kernels/ClGemmLowpReductionKernel.h"
37 
38 #include "src/common/utils/Log.h"
39 
40 namespace arm_compute
41 {
42 using namespace arm_compute::utils::info_helpers;
43 using namespace arm_compute::opencl::kernels;
44 namespace
45 {
validate_mm(GEMMLowpOutputStageInfo & gemmlowp_info,const ITensorInfo * mm_input,const ITensorInfo * mm_weights,const ITensorInfo * bias,float gemmlowp_scale,const TensorInfo * mm_res_info,const TensorInfo * outstage_tensor_info)46 Status validate_mm(GEMMLowpOutputStageInfo &gemmlowp_info, const ITensorInfo *mm_input, const ITensorInfo *mm_weights, const ITensorInfo *bias,
47                    float gemmlowp_scale, const TensorInfo *mm_res_info, const TensorInfo *outstage_tensor_info)
48 {
49     ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyCore::validate(mm_input, mm_weights, nullptr, mm_res_info));
50     ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(gemmlowp_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
51     ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpOutputStage::validate(mm_res_info, bias, outstage_tensor_info, gemmlowp_info));
52     return Status{};
53 }
54 } // namespace
55 
validate(const ITensorInfo & src,const ITensorInfo & dst)56 Status CLQLSTMLayer::TensorCopyKernel::validate(const ITensorInfo &src, const ITensorInfo &dst)
57 {
58     ARM_COMPUTE_RETURN_ERROR_ON(src.tensor_shape().num_dimensions() > max_dimension_supported);
59     ARM_COMPUTE_RETURN_ERROR_ON(dst.tensor_shape().num_dimensions() > max_dimension_supported);
60     ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(&src, &dst);
61     ARM_COMPUTE_RETURN_ERROR_ON(dst.tensor_shape().y() != src.tensor_shape().y());
62     return Status{};
63 }
64 
configure(ICLTensor & src,ICLTensor & dst)65 void CLQLSTMLayer::TensorCopyKernel::configure(ICLTensor &src, ICLTensor &dst)
66 {
67     ARM_COMPUTE_ERROR_THROW_ON(CLQLSTMLayer::TensorCopyKernel::validate(*src.info(), *dst.info()));
68     _src      = &src;
69     _dst      = &dst;
70     _row_size = std::min(_src->info()->tensor_shape().x(), _dst->info()->tensor_shape().x());
71     _window   = calculate_max_window(*_src->info(), Steps());
72 }
73 
run()74 void CLQLSTMLayer::TensorCopyKernel::run()
75 {
76     auto &q = CLScheduler::get().queue();
77 
78     _src->map(q, true);
79     _dst->map(q, true);
80 
81     Iterator input_iter{ _src, _window };
82     Iterator output_iter{ _dst, _window };
83 
84     execute_window_loop(_window, [&](const Coordinates &)
85     {
86         memcpy(output_iter.ptr(), input_iter.ptr(), _row_size);
87     },
88     input_iter, output_iter);
89 
90     _src->unmap(q);
91     _dst->unmap(q);
92 }
93 
CLQLSTMLayer(std::shared_ptr<IMemoryManager> memory_manager)94 CLQLSTMLayer::CLQLSTMLayer(std::shared_ptr<IMemoryManager> memory_manager)
95     : _input_to_input_reduction(std::make_unique<ClGemmLowpMatrixAReductionKernel>()),
96       _recurrent_to_input_reduction(std::make_unique<ClGemmLowpMatrixAReductionKernel>()),
97       _input_to_forget_reduction(std::make_unique<ClGemmLowpMatrixAReductionKernel>()),
98       _recurrent_to_forget_reduction(std::make_unique<ClGemmLowpMatrixAReductionKernel>()),
99       _input_to_cell_reduction(std::make_unique<ClGemmLowpMatrixAReductionKernel>()),
100       _recurrent_to_cell_reduction(std::make_unique<ClGemmLowpMatrixAReductionKernel>()),
101       _input_to_output_reduction(std::make_unique<ClGemmLowpMatrixAReductionKernel>()),
102       _recurrent_to_output_reduction(std::make_unique<ClGemmLowpMatrixAReductionKernel>()),
103       _projection_reduction(std::make_unique<ClGemmLowpMatrixAReductionKernel>()),
104       _layer_norms(),
105       _copy_output()
106 {
107     for(auto &norm : _layer_norms)
108     {
109         norm = std::make_unique<CLQLSTMLayerNormalizationKernel>();
110     }
111 
112     _memory_group = MemoryGroup(std::move(memory_manager));
113 }
114 
// Destructor defaulted out-of-line in this translation unit, where the kernel headers are included.
// NOTE(review): presumably required because the class header only forward-declares the kernel types
// held through std::unique_ptr - confirm against CLQLSTMLayer.h.
115 CLQLSTMLayer::~CLQLSTMLayer() = default;
116 
configure_layer_norm(LayerNormGate g,const ICLTensor * in)117 void CLQLSTMLayer::configure_layer_norm(LayerNormGate g, const ICLTensor *in)
118 {
119     ARM_COMPUTE_ERROR_ON(!_has_layer_norm);
120 
121     CLTensor *out = &get_layer_norm_output(g);
122     _memory_group.manage(out);
123     out->allocator()->init(*(in->info()));
124 
125     get_layer_norm(g).configure(in, out, get_layer_norm_weight(g), get_layer_norm_bias(g));
126 }
127 
validate_layer_norm(const ITensorInfo & in,const ITensorInfo & weight,const ITensorInfo & bias)128 Status CLQLSTMLayer::validate_layer_norm(const ITensorInfo &in, const ITensorInfo &weight, const ITensorInfo &bias)
129 {
130     // Output quantization scale will be different, but ignored here
131     // since it will be configured at configure() stage.
132     const TensorInfo out
133     {
134         in
135     };
136     return CLQLSTMLayerNormalizationKernel::validate(&in, &out, &weight, &bias);
137 }
138 
configure_mm(const CLCompileContext & compile_context,CLGEMMLowpMatrixMultiplyCore & mm,CLGEMMLowpOutputStage & outstage,GEMMLowpOutputStageInfo & gemmlowp_info,const ICLTensor * mm_input,const ICLTensor * mm_weights,const ICLTensor * bias,CLTensor * mm_res,CLTensor * outstage_res,float gemmlowp_scale,const TensorInfo & mm_res_info,const TensorInfo & outstage_tensor_info)139 void CLQLSTMLayer::configure_mm(const CLCompileContext &compile_context, CLGEMMLowpMatrixMultiplyCore &mm, CLGEMMLowpOutputStage &outstage, GEMMLowpOutputStageInfo &gemmlowp_info,
140                                 const ICLTensor *mm_input, const ICLTensor *mm_weights, const ICLTensor *bias,
141                                 CLTensor *mm_res, CLTensor *outstage_res, float gemmlowp_scale,
142                                 const TensorInfo &mm_res_info, const TensorInfo &outstage_tensor_info)
143 {
144     _memory_group.manage(mm_res);
145     _memory_group.manage(outstage_res);
146 
147     mm_res->allocator()->init(mm_res_info);
148     outstage_res->allocator()->init(outstage_tensor_info);
149 
150     // Configure matrix-multiplication
151     mm.configure(compile_context, mm_input, mm_weights, nullptr, mm_res);
152 
153     // Configure output stage
154     quantization::calculate_quantized_multiplier(gemmlowp_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift);
155     outstage.configure(compile_context, mm_res, bias, outstage_res, gemmlowp_info);
156     mm_res->allocator()->allocate();
157 }
158 
configure(const ICLTensor * input,const ICLTensor * input_to_forget_weights,const ICLTensor * input_to_cell_weights,const ICLTensor * input_to_output_weights,const ICLTensor * recurrent_to_forget_weights,const ICLTensor * recurrent_to_cell_weights,const ICLTensor * recurrent_to_output_weights,const ICLTensor * forget_gate_bias,const ICLTensor * cell_bias,const ICLTensor * output_gate_bias,ICLTensor * cell_state_in,ICLTensor * output_state_in,ICLTensor * cell_state_out,ICLTensor * output_state_out,ICLTensor * output,const LSTMParams<ICLTensor> & lstm_params)159 void CLQLSTMLayer::configure(const ICLTensor *input,
160                              const ICLTensor *input_to_forget_weights, const ICLTensor *input_to_cell_weights, const ICLTensor *input_to_output_weights,
161                              const ICLTensor *recurrent_to_forget_weights, const ICLTensor *recurrent_to_cell_weights, const ICLTensor *recurrent_to_output_weights,
162                              const ICLTensor *forget_gate_bias, const ICLTensor *cell_bias, const ICLTensor *output_gate_bias,
163                              ICLTensor *cell_state_in, ICLTensor *output_state_in,
164                              ICLTensor *cell_state_out, ICLTensor *output_state_out, ICLTensor *output,
165                              const LSTMParams<ICLTensor> &lstm_params)
166 {
167     configure(CLKernelLibrary::get().get_compile_context(), input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
168               recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights, forget_gate_bias, cell_bias, output_gate_bias,
169               cell_state_in, output_state_in, cell_state_out, output_state_out, output, lstm_params);
170 }
171 
configure(const CLCompileContext & compile_context,const ICLTensor * input,const ICLTensor * input_to_forget_weights,const ICLTensor * input_to_cell_weights,const ICLTensor * input_to_output_weights,const ICLTensor * recurrent_to_forget_weights,const ICLTensor * recurrent_to_cell_weights,const ICLTensor * recurrent_to_output_weights,const ICLTensor * forget_gate_bias,const ICLTensor * cell_bias,const ICLTensor * output_gate_bias,ICLTensor * cell_state_in,ICLTensor * output_state_in,ICLTensor * cell_state_out,ICLTensor * output_state_out,ICLTensor * output,const LSTMParams<ICLTensor> & lstm_params)172 void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLTensor *input,
173                              const ICLTensor *input_to_forget_weights, const ICLTensor *input_to_cell_weights, const ICLTensor *input_to_output_weights,
174                              const ICLTensor *recurrent_to_forget_weights, const ICLTensor *recurrent_to_cell_weights, const ICLTensor *recurrent_to_output_weights,
175                              const ICLTensor *forget_gate_bias, const ICLTensor *cell_bias, const ICLTensor *output_gate_bias,
176                              ICLTensor *cell_state_in, ICLTensor *output_state_in,
177                              ICLTensor *cell_state_out, ICLTensor *output_state_out, ICLTensor *output,
178                              const LSTMParams<ICLTensor> &lstm_params)
179 {
180     ARM_COMPUTE_ERROR_ON_NULLPTR(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
181                                  recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
182                                  forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in,
183                                  cell_state_out, output_state_out, output);
184 
185     ARM_COMPUTE_LOG_PARAMS(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
186                            recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
187                            forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in,
188                            cell_state_out, output_state_out, output, lstm_params);
189     // Set lstm parameters
190     LSTMParams<ITensorInfo> lstm_params_info{};
191     build_lstm_params_tensor_info(lstm_params, &lstm_params_info);
192 
193     // Validate
194     ARM_COMPUTE_ERROR_THROW_ON(CLQLSTMLayer::validate(input->info(), input_to_forget_weights->info(), input_to_cell_weights->info(), input_to_output_weights->info(),
195                                                       recurrent_to_forget_weights->info(), recurrent_to_cell_weights->info(), recurrent_to_output_weights->info(),
196                                                       forget_gate_bias->info(), cell_bias->info(), output_gate_bias->info(),
197                                                       cell_state_in->info(), output_state_in->info(), cell_state_out->info(), output_state_out->info(), output->info(),
198                                                       lstm_params_info));
199 
200     const int batch_size  = input->info()->dimension(1);
201     const int num_units   = input_to_output_weights->info()->dimension(1);
202     const int output_size = output_state_out->info()->dimension(_out_state_output_size_dimension_idx);
203 
204     const UniformQuantizationInfo qinput           = input->info()->quantization_info().uniform();
205     const UniformQuantizationInfo qcell_state_in   = cell_state_in->info()->quantization_info().uniform();
206     const UniformQuantizationInfo qoutput_state_in = output_state_in->info()->quantization_info().uniform();
207 
208     _projection_bias             = lstm_params.projection_bias();
209     _input_to_forget_weights     = input_to_forget_weights;
210     _input_to_cell_weights       = input_to_cell_weights;
211     _input_to_output_weights     = input_to_output_weights;
212     _recurrent_to_forget_weights = recurrent_to_forget_weights;
213     _recurrent_to_cell_weights   = recurrent_to_cell_weights;
214     _recurrent_to_output_weights = recurrent_to_output_weights;
215     _projection_weights          = lstm_params.projection_weights();
216 
217     // Layer normalization
218     _has_layer_norm = lstm_params.use_layer_norm();
219     if(_has_layer_norm)
220     {
221         set_layer_norm_weight(lstm_params.forget_layer_norm_weights(), LayerNormGate::Forget);
222         set_layer_norm_weight(lstm_params.cell_layer_norm_weights(), LayerNormGate::Cell);
223         set_layer_norm_weight(lstm_params.input_layer_norm_weights(), LayerNormGate::Input);
224         set_layer_norm_weight(lstm_params.output_layer_norm_weights(), LayerNormGate::Output);
225 
226         set_layer_norm_bias(forget_gate_bias, LayerNormGate::Forget);
227         set_layer_norm_bias(cell_bias, LayerNormGate::Cell);
228         set_layer_norm_bias(lstm_params.input_gate_bias(), LayerNormGate::Input);
229         set_layer_norm_bias(output_gate_bias, LayerNormGate::Output);
230     }
231 
232     _has_cifg       = lstm_params.has_cifg_opt();
233     _has_projection = lstm_params.has_projection();
234     _has_peephole   = lstm_params.has_peephole_opt();
235 
236     // Calculate and decompose effective scales for optimizing matmul calculation
237     const int32_t cell_shift = log2(qcell_state_in.scale);
238 
239     // Calculate quantized parameters for clipping.
240     int16_t quantized_cell_clip = 0;
241     if(lstm_params.cell_clip() > 0.0f)
242     {
243         quantized_cell_clip = quantize_qsymm16(lstm_params.cell_clip(), qcell_state_in);
244     }
245     _has_cell_clipping = quantized_cell_clip > 0;
246 
247     // Precompute effective bias for optimizing the matmul computations.
248     if(!_has_cifg)
249     {
250         _input_to_input_weights     = lstm_params.input_to_input_weights();
251         _recurrent_to_input_weights = lstm_params.recurrent_to_input_weights();
252 
253         _input_to_input_reduction->configure(compile_context, _input_to_input_weights->info(), _input_to_input_eff_bias.info(), GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true));
254         _recurrent_to_input_reduction->configure(compile_context, _recurrent_to_input_weights->info(), _recurrent_to_input_eff_bias.info(), GEMMLowpReductionKernelInfo(num_units, false,
255                                                  -qoutput_state_in.offset, true));
256     }
257     _input_to_forget_reduction->configure(compile_context, input_to_forget_weights->info(), _input_to_forget_eff_bias.info(), GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true));
258     _recurrent_to_forget_reduction->configure(compile_context, recurrent_to_forget_weights->info(), _recurrent_to_forget_eff_bias.info(), GEMMLowpReductionKernelInfo(num_units, false,
259                                               -qoutput_state_in.offset, true));
260     _input_to_cell_reduction->configure(compile_context, input_to_cell_weights->info(), _input_to_cell_eff_bias.info(), GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true));
261     _recurrent_to_cell_reduction->configure(compile_context, recurrent_to_cell_weights->info(), _recurrent_to_cell_eff_bias.info(), GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset,
262                                             true));
263     _input_to_output_reduction->configure(compile_context, input_to_output_weights->info(), _input_to_output_eff_bias.info(), GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true));
264     _recurrent_to_output_reduction->configure(compile_context, recurrent_to_output_weights->info(), _recurrent_to_output_eff_bias.info(), GEMMLowpReductionKernelInfo(num_units, false,
265                                               -qoutput_state_in.offset, true));
266     if(_has_projection)
267     {
268         _projection_reduction->configure(compile_context, _projection_weights->info(), _projection_eff_bias.info(), GEMMLowpReductionKernelInfo(output_size, false, lstm_params.hidden_state_zero(), true));
269         if(_projection_bias != nullptr)
270         {
271             _projection_bias_add.configure(compile_context, _projection_bias, &_projection_eff_bias, &_projection_eff_bias, ConvertPolicy::SATURATE);
272         }
273     }
274 
275     // Pre-transpose weights to be used in GEMM.
276     _transpose_input_to_forget_weights.configure(compile_context, input_to_forget_weights, &_input_to_forget_weights_transposed);
277     _transpose_input_to_cell_weights.configure(compile_context, input_to_cell_weights, &_input_to_cell_weights_transposed);
278     _transpose_input_to_output_weights.configure(compile_context, input_to_output_weights, &_input_to_output_weights_transposed);
279     _transpose_recurrent_to_forget_weights.configure(compile_context, recurrent_to_forget_weights, &_recurrent_to_forget_weights_transposed);
280     _transpose_recurrent_to_cell_weights.configure(compile_context, recurrent_to_cell_weights, &_recurrent_to_cell_weights_transposed);
281     _transpose_recurrent_to_output_weights.configure(compile_context, recurrent_to_output_weights, &_recurrent_to_output_weights_transposed);
282     if(!_has_cifg)
283     {
284         _transpose_input_to_input_weights.configure(compile_context, lstm_params.input_to_input_weights(), &_input_to_input_weights_transposed);
285         _transpose_recurrent_to_input_weights.configure(compile_context, lstm_params.recurrent_to_input_weights(), &_recurrent_to_input_weights_transposed);
286     }
287     if(_has_projection)
288     {
289         _transpose_projection_weights.configure(compile_context, _projection_weights, &_projection_weights_transposed);
290     }
291 
292     GEMMLowpOutputStageInfo gemmlowp_info;
293     gemmlowp_info.type               = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
294     gemmlowp_info.gemmlowp_min_bound = std::numeric_limits<int16_t>::lowest();
295     gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int16_t>::max();
296     gemmlowp_info.output_data_type   = DataType::QSYMM16;
297 
298     const TensorInfo mm_out_info(TensorShape(num_units, batch_size), 1, DataType::S32);
299     // Forget gate.
300     const TensorInfo forget_gate_outstage_info(mm_out_info.tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.forget_intermediate_scale(), 0));
301     const float      input_to_forget_scale = input_to_forget_weights->info()->quantization_info().uniform().scale * qinput.scale / lstm_params.forget_intermediate_scale();
302     configure_mm(compile_context, _mm_input_to_forget, _input_to_forget_outstage, gemmlowp_info,
303                  input, &_input_to_forget_weights_transposed, &_input_to_forget_eff_bias,
304                  &_mm_input_to_forget_res, &_input_to_forget_outstage_res, input_to_forget_scale,
305                  mm_out_info, forget_gate_outstage_info);
306 
307     const float recurrent_to_forget_scale = recurrent_to_forget_weights->info()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.forget_intermediate_scale();
308     configure_mm(compile_context, _mm_recurrent_to_forget, _recurrent_to_forget_outstage, gemmlowp_info,
309                  output_state_in, &_recurrent_to_forget_weights_transposed, &_recurrent_to_forget_eff_bias,
310                  &_mm_recurrent_to_forget_res, &_recurrent_to_forget_outstage_res, recurrent_to_forget_scale,
311                  mm_out_info, forget_gate_outstage_info);
312 
313     _accumulate_input_recurrent_forget.configure(compile_context, &_input_to_forget_outstage_res, &_recurrent_to_forget_outstage_res, &_recurrent_to_forget_outstage_res,
314                                                  ConvertPolicy::SATURATE);
315     _input_to_forget_outstage_res.allocator()->allocate();
316 
317     if(_has_peephole)
318     {
319         _mul_cell_to_forget_res.allocator()->init(TensorInfo(cell_state_in->info()->tensor_shape(), 1, DataType::S32));
320         _memory_group.manage(&_mul_cell_to_forget_res);
321         _pixelwise_mul_cell_to_forget.configure(compile_context, cell_state_in, lstm_params.cell_to_forget_weights(), &_mul_cell_to_forget_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
322         _cell_to_forget_outstage_res.allocator()->init(TensorInfo(_mul_cell_to_forget_res.info()->tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.forget_intermediate_scale(), 0)));
323         _memory_group.manage(&_cell_to_forget_outstage_res);
324         const float cell_to_forget_scale = std::pow(2, cell_shift) * lstm_params.cell_to_forget_weights()->info()->quantization_info().uniform().scale / lstm_params.forget_intermediate_scale();
325         quantization::calculate_quantized_multiplier(cell_to_forget_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift);
326         _cell_to_forget_outstage.configure(compile_context, &_mul_cell_to_forget_res, nullptr, &_cell_to_forget_outstage_res, gemmlowp_info);
327         _mul_cell_to_forget_res.allocator()->allocate();
328         _accumulate_cell_forget.configure(compile_context, &_recurrent_to_forget_outstage_res, &_cell_to_forget_outstage_res, &_recurrent_to_forget_outstage_res,
329                                           ConvertPolicy::SATURATE);
330         _cell_to_forget_outstage_res.allocator()->allocate();
331     }
332 
333     CLTensor *forget_activation_input = &_recurrent_to_forget_outstage_res;
334 
335     if(_has_layer_norm)
336     {
337         configure_layer_norm(LayerNormGate::Forget, &_recurrent_to_forget_outstage_res);
338         _recurrent_to_forget_outstage_res.allocator()->allocate();
339         forget_activation_input = &get_layer_norm_output(LayerNormGate::Forget);
340     }
341 
342     // Output quantization info of Sigmoid and Tanh activations
343     const QuantizationInfo sigmoid_tanh_outqinfo(1.f / 32768.f, 0);
344 
345     const TensorInfo forget_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
346     _memory_group.manage(&_forget_gate);
347     _forget_gate.allocator()->init(forget_gate_info);
348     _forget_gate_sigmoid.configure(compile_context, forget_activation_input, &_forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
349     forget_activation_input->allocator()->allocate();
350 
351     // Modulation gate.
352     const TensorInfo cell_outstage_info(mm_out_info.tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.cell_intermediate_scale(), 0));
353     const float      input_to_cell_scale = input_to_cell_weights->info()->quantization_info().uniform().scale * qinput.scale / lstm_params.cell_intermediate_scale();
354     configure_mm(compile_context, _mm_input_to_cell, _input_to_cell_outstage, gemmlowp_info,
355                  input, &_input_to_cell_weights_transposed, &_input_to_cell_eff_bias,
356                  &_mm_input_to_cell_res, &_input_to_cell_outstage_res, input_to_cell_scale,
357                  mm_out_info, cell_outstage_info);
358 
359     const float recurrent_to_cell_scale = recurrent_to_cell_weights->info()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.cell_intermediate_scale();
360     configure_mm(compile_context, _mm_recurrent_to_cell, _recurrent_to_cell_outstage, gemmlowp_info,
361                  output_state_in, &_recurrent_to_cell_weights_transposed, &_recurrent_to_cell_eff_bias,
362                  &_mm_recurrent_to_cell_res, &_recurrent_to_cell_outstage_res, recurrent_to_cell_scale,
363                  mm_out_info, cell_outstage_info);
364 
365     _accumulate_input_recurrent_modulation.configure(compile_context, &_input_to_cell_outstage_res, &_recurrent_to_cell_outstage_res, &_recurrent_to_cell_outstage_res,
366                                                      ConvertPolicy::SATURATE);
367     _input_to_cell_outstage_res.allocator()->allocate();
368 
369     CLTensor *cell_activation_input = &_recurrent_to_cell_outstage_res;
370 
371     if(_has_layer_norm)
372     {
373         configure_layer_norm(LayerNormGate::Cell, &_recurrent_to_cell_outstage_res);
374         _recurrent_to_cell_outstage_res.allocator()->allocate();
375         cell_activation_input = &get_layer_norm_output(LayerNormGate::Cell);
376     }
377 
378     const TensorInfo cell_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
379     _memory_group.manage(&_cell_gate);
380     _cell_gate.allocator()->init(cell_gate_info);
381     _cell_gate_tanh.configure(compile_context, cell_activation_input, &_cell_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
382     cell_activation_input->allocator()->allocate();
383 
384     // Input gate.
385     const TensorInfo input_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
386     _input_gate.allocator()->init(input_gate_info);
387     _memory_group.manage(&_input_gate);
388     if(_has_cifg)
389     {
390         _ones.allocator()->init(*_forget_gate.info());
391         _input_gate_sub.configure(compile_context, &_ones, &_forget_gate, &_input_gate, ConvertPolicy::SATURATE);
392         _ones.allocator()->allocate();
393     }
394     else
395     {
396         const TensorInfo input_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.input_intermediate_scale(), 0));
397         const float      input_to_input_scale = _input_to_input_weights->info()->quantization_info().uniform().scale * qinput.scale / lstm_params.input_intermediate_scale();
398         configure_mm(compile_context, _mm_input_to_input, _input_to_input_outstage, gemmlowp_info,
399                      input, &_input_to_input_weights_transposed, &_input_to_input_eff_bias,
400                      &_mm_input_to_input_res, &_input_to_input_outstage_res, input_to_input_scale,
401                      mm_out_info, input_outstage_info);
402 
403         const float recurrent_to_input_scale = _recurrent_to_input_weights->info()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.input_intermediate_scale();
404         configure_mm(compile_context, _mm_recurrent_to_input, _recurrent_to_input_outstage, gemmlowp_info,
405                      output_state_in, &_recurrent_to_input_weights_transposed, &_recurrent_to_input_eff_bias,
406                      &_mm_recurrent_to_input_res, &_recurrent_to_input_outstage_res, recurrent_to_input_scale,
407                      mm_out_info, input_outstage_info);
408         _accumulate_input_recurrent_input.configure(compile_context, &_input_to_input_outstage_res, &_recurrent_to_input_outstage_res, &_recurrent_to_input_outstage_res,
409                                                     ConvertPolicy::SATURATE);
410         _input_to_input_outstage_res.allocator()->allocate();
411 
412         if(_has_peephole)
413         {
414             _mul_cell_to_input_res.allocator()->init(TensorInfo(cell_state_in->info()->tensor_shape(), 1, DataType::S32));
415             _memory_group.manage(&_mul_cell_to_input_res);
416             _pixelwise_mul_cell_to_input.configure(compile_context, cell_state_in, lstm_params.cell_to_input_weights(), &_mul_cell_to_input_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
417             const float cell_to_input_scale = std::pow(2, cell_shift) * lstm_params.cell_to_input_weights()->info()->quantization_info().uniform().scale / lstm_params.input_intermediate_scale();
418             quantization::calculate_quantized_multiplier(cell_to_input_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift);
419             _cell_to_input_outstage_res.allocator()->init(TensorInfo(_mul_cell_to_input_res.info()->tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.input_intermediate_scale(), 0)));
420             _memory_group.manage(&_cell_to_input_outstage_res);
421             _cell_to_input_outstage.configure(compile_context, &_mul_cell_to_input_res, nullptr, &_cell_to_input_outstage_res, gemmlowp_info);
422             _mul_cell_to_input_res.allocator()->allocate();
423             _accumulate_cell_input.configure(&_recurrent_to_input_outstage_res, &_cell_to_input_outstage_res, &_recurrent_to_input_outstage_res, ConvertPolicy::SATURATE);
424             _cell_to_input_outstage_res.allocator()->allocate();
425         }
426 
427         CLTensor *input_activation_input = &_recurrent_to_input_outstage_res;
428 
429         if(_has_layer_norm)
430         {
431             configure_layer_norm(LayerNormGate::Input, &_recurrent_to_input_outstage_res);
432             _recurrent_to_input_outstage_res.allocator()->allocate();
433             input_activation_input = &get_layer_norm_output(LayerNormGate::Input);
434         }
435 
436         _input_gate_sigmoid.configure(compile_context, input_activation_input, &_input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
437         input_activation_input->allocator()->allocate();
438     }
439     // Cell.
440     // TODO(COMPMID-3396): Perform multiplication in the quantized domain in CLPixelWiseMultiplication
441     _pixelwise_mul_forget_cell.configure(compile_context, &_forget_gate, cell_state_in, &_forget_gate, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
442     const float      cell_gate_scale      = _cell_gate.info()->quantization_info().uniform().scale;
443     const float      mul_input_cell_scale = cell_gate_scale * std::pow(2, 15 + cell_shift);
444     const TensorInfo mul_input_cell_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(mul_input_cell_scale, 0));
445     _memory_group.manage(&_mul_input_cell_res);
446     _mul_input_cell_res.allocator()->init(mul_input_cell_info);
447     _pixelwise_mul_input_cell.configure(compile_context, &_input_gate, &_cell_gate, &_mul_input_cell_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
448     _cell_gate.allocator()->allocate();
449     _add_forget_cell.configure(compile_context, &_forget_gate, &_mul_input_cell_res, cell_state_out, ConvertPolicy::SATURATE);
450     _mul_input_cell_res.allocator()->allocate();
451     _forget_gate.allocator()->allocate();
452     if(_has_cell_clipping)
453     {
454         _cell_clip.configure(compile_context, cell_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -quantized_cell_clip, quantized_cell_clip));
455     }
456     // Output gate.
457     const TensorInfo output_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.output_intermediate_scale(), 0));
458     const float      input_to_output_scale = input_to_output_weights->info()->quantization_info().uniform().scale * qinput.scale / lstm_params.output_intermediate_scale();
459     configure_mm(compile_context, _mm_input_to_output, _input_to_output_outstage, gemmlowp_info,
460                  input, &_input_to_output_weights_transposed, &_input_to_output_eff_bias,
461                  &_mm_input_to_output_res, &_input_to_output_outstage_res, input_to_output_scale,
462                  mm_out_info, output_outstage_info);
463 
464     const float recurrent_to_output_scale = recurrent_to_output_weights->info()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.output_intermediate_scale();
465     configure_mm(compile_context, _mm_recurrent_to_output, _recurrent_to_output_outstage, gemmlowp_info,
466                  output_state_in, &_recurrent_to_output_weights_transposed, &_recurrent_to_output_eff_bias,
467                  &_mm_recurrent_to_output_res, &_recurrent_to_output_outstage_res, recurrent_to_output_scale,
468                  mm_out_info, output_outstage_info);
469 
470     _accumulate_input_recurrent_output.configure(compile_context, &_recurrent_to_output_outstage_res, &_input_to_output_outstage_res, &_recurrent_to_output_outstage_res,
471                                                  ConvertPolicy::SATURATE);
472     _input_to_output_outstage_res.allocator()->allocate();
473 
474     if(_has_peephole)
475     {
476         // TODO(COMPMID-3396): Perform multiplication in the quantized domain in CLPixelWiseMultiplication
477         // Here we are not using the output stage because all operations are done in float
478         _mul_cell_to_output_res.allocator()->init(TensorInfo(cell_state_out->info()->tensor_shape(), 1, DataType::S32));
479         _memory_group.manage(&_mul_cell_to_output_res);
480         _pixelwise_mul_cell_to_output.configure(compile_context, cell_state_out, lstm_params.cell_to_output_weights(), &_mul_cell_to_output_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
481 
482         const float cell_to_output_scale = std::pow(2, cell_shift) * lstm_params.cell_to_output_weights()->info()->quantization_info().uniform().scale / lstm_params.output_intermediate_scale();
483         quantization::calculate_quantized_multiplier(cell_to_output_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift);
484         _cell_to_output_outstage_res.allocator()->init(TensorInfo(_mul_cell_to_output_res.info()->tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.output_intermediate_scale(), 0)));
485         _memory_group.manage(&_cell_to_output_outstage_res);
486         _cell_to_output_outstage.configure(compile_context, &_mul_cell_to_output_res, nullptr, &_cell_to_output_outstage_res, gemmlowp_info);
487         _mul_cell_to_output_res.allocator()->allocate();
488 
489         _accumulate_cell_to_output.configure(compile_context, &_recurrent_to_output_outstage_res, &_cell_to_output_outstage_res, &_recurrent_to_output_outstage_res,
490                                              ConvertPolicy::SATURATE);
491         _cell_to_output_outstage_res.allocator()->allocate();
492     }
493 
494     CLTensor *output_activation_input = &_recurrent_to_output_outstage_res;
495 
496     if(_has_layer_norm)
497     {
498         configure_layer_norm(LayerNormGate::Output, &_recurrent_to_output_outstage_res);
499         _recurrent_to_output_outstage_res.allocator()->allocate();
500         output_activation_input = &get_layer_norm_output(LayerNormGate::Output);
501     }
502 
503     const TensorInfo output_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
504     _memory_group.manage(&_output_gate);
505     _output_gate.allocator()->init(output_gate_info);
506     _output_gate_sigmoid.configure(compile_context, output_activation_input, &_output_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
507     output_activation_input->allocator()->allocate();
508 
509     // Hidden.
510     _hidden_tanh.configure(compile_context, cell_state_out, &_input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
511     // TODO(COMPMID-3396): Perform multiplication in the quantized domain in CLPixelWiseMultiplication
512     _memory_group.manage(&_hidden_mul_res);
513     const TensorInfo hidden_mul_res(_input_gate.info()->tensor_shape(), 1, DataType::S32);
514     _hidden_mul_res.allocator()->init(hidden_mul_res);
515     _pixelwise_mul_hidden.configure(compile_context, &_output_gate, &_input_gate, &_hidden_mul_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
516     _output_gate.allocator()->allocate();
517     _input_gate.allocator()->allocate();
518     const float hidden_state_scale = std::pow(2, -15) / lstm_params.hidden_state_scale() * std::pow(2, -15);
519     quantization::calculate_quantized_multiplier(hidden_state_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift, /* ignore_epsilon */ true);
520     gemmlowp_info.gemmlowp_offset  = lstm_params.hidden_state_zero();
521     gemmlowp_info.output_data_type = output_state_in->info()->data_type();
522 
523     _projection_tensor_copy_required = (num_units != output_size);
524     ICLTensor *hidden_gate_result    = output_state_out;
525 
526     _memory_group.manage(&_hidden_gate);
527 
528     if(_projection_tensor_copy_required)
529     {
530         _hidden_gate.allocator()->init(*output_state_out->info());
531         _hidden_gate.info()->set_tensor_shape(_hidden_mul_res.info()->tensor_shape());
532         hidden_gate_result = &_hidden_gate;
533     }
534 
535     _hidden_outstage.configure(compile_context, &_hidden_mul_res, nullptr, hidden_gate_result, gemmlowp_info);
536     _hidden_mul_res.allocator()->allocate();
537 
538     // Projection.
539     if(_has_projection)
540     {
541         const TensorInfo              projection_outstage_info(*output_state_out->info());
542         const UniformQuantizationInfo qprojection      = _projection_weights->info()->quantization_info().uniform();
543         const float                   projection_scale = qprojection.scale * lstm_params.hidden_state_scale() / qoutput_state_in.scale;
544         gemmlowp_info.gemmlowp_offset                  = qoutput_state_in.offset;
545         gemmlowp_info.gemmlowp_min_bound               = std::numeric_limits<int8_t>::lowest();
546         gemmlowp_info.gemmlowp_max_bound               = std::numeric_limits<int8_t>::max();
547         gemmlowp_info.output_data_type                 = DataType::QASYMM8_SIGNED;
548 
549         TensorInfo projection_mm_out_info{ mm_out_info };
550         projection_mm_out_info.set_tensor_shape(TensorShape(output_size, batch_size));
551 
552         configure_mm(compile_context, _mm_projection, _projection_outstage, gemmlowp_info,
553                      hidden_gate_result, &_projection_weights_transposed, &_projection_eff_bias,
554                      &_mm_projection_res, &_projection_outstage_res, projection_scale,
555                      projection_mm_out_info, projection_outstage_info);
556 
557         ICLTensor *accumulate_destination = output_state_out;
558 
559         if(_projection_tensor_copy_required)
560         {
561             _hidden_gate.allocator()->allocate();
562             _projection_accumulate_res.allocator()->init(*output_state_in->info());
563             _projection_accumulate_res.info()->set_tensor_shape(_projection_outstage_res.info()->tensor_shape());
564             _projection_output_to_accumulate_copy.configure(*output_state_in, _projection_accumulate_res);
565             accumulate_destination = &_projection_accumulate_res;
566         }
567 
568         _accumulate_projection.configure(compile_context, &_projection_outstage_res, accumulate_destination, accumulate_destination, ConvertPolicy::SATURATE);
569         _projection_outstage_res.allocator()->allocate();
570 
571         if(_projection_tensor_copy_required)
572         {
573             _projection_accumulate_to_output_copy.configure(_projection_accumulate_res, *output_state_out);
574             _projection_accumulate_res.allocator()->allocate();
575         }
576 
577         int8_t quantized_projection_clip{ 0 };
578         if(lstm_params.projection_clip() > 0.0f)
579         {
580             quantized_projection_clip = utility::clamp<int8_t>(lstm_params.projection_clip() / qprojection.scale, -128, 127);
581         }
582 
583         if(quantized_projection_clip > 0)
584         {
585             _projection_clip.configure(compile_context, output_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -quantized_projection_clip,
586                                                                                                        quantized_projection_clip));
587             _has_projection_clipping = true;
588         }
589     }
590     else
591     {
592         if(_projection_tensor_copy_required)
593         {
594             _hidden_to_output_copy.configure(_hidden_gate, *output_state_out);
595             _hidden_gate.allocator()->allocate();
596         }
597     }
598 
599     // Copy output_state_out to output
600     _copy_output.configure(compile_context, output_state_out, output);
601 }
602 
603 Status CLQLSTMLayer::validate(const ITensorInfo *input,
604                               const ITensorInfo *input_to_forget_weights, const ITensorInfo *input_to_cell_weights, const ITensorInfo *input_to_output_weights,
605                               const ITensorInfo *recurrent_to_forget_weights, const ITensorInfo *recurrent_to_cell_weights, const ITensorInfo *recurrent_to_output_weights,
606                               const ITensorInfo *forget_gate_bias, const ITensorInfo *cell_bias, const ITensorInfo *output_gate_bias,
607                               const ITensorInfo *cell_state_in, const ITensorInfo *output_state_in,
608                               const ITensorInfo *cell_state_out, const ITensorInfo *output_state_out, const ITensorInfo *output,
609                               const LSTMParams<ITensorInfo> &lstm_params)
610 {
611     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights, recurrent_to_forget_weights, recurrent_to_cell_weights,
612                                         recurrent_to_output_weights, forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in,
613                                         cell_state_out, output_state_out, output);
614 
615     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8_SIGNED);
616     ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->num_dimensions() != 2, "Input must have exactly 2 dimensions");
617 
618     const unsigned int input_size  = input->dimension(0);
619     const unsigned int batch_size  = input->dimension(1);
620     const unsigned int num_units   = input_to_output_weights->dimension(1);
621     const unsigned int output_size = output_state_out->dimension(_out_state_output_size_dimension_idx);
622 
623     ARM_COMPUTE_RETURN_ERROR_ON(input_to_output_weights->num_dimensions() != 2);
624     ARM_COMPUTE_RETURN_ERROR_ON(input_to_output_weights->dimension(0) != input_size);
625     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input_to_output_weights, input_to_forget_weights, input_to_cell_weights);
626     ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_output_weights->num_dimensions() != 2);
627     ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_output_weights->dimension(1) != num_units);
628     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(recurrent_to_output_weights, recurrent_to_forget_weights, recurrent_to_cell_weights);
629     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input_to_forget_weights, 1, DataType::QSYMM8);
630     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
631                                                        recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights);
632 
633     ARM_COMPUTE_RETURN_ERROR_ON(forget_gate_bias->num_dimensions() != 1);
634     ARM_COMPUTE_RETURN_ERROR_ON(forget_gate_bias->dimension(0) != num_units);
635     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(forget_gate_bias, cell_bias, output_gate_bias);
636     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(forget_gate_bias, 1, DataType::S32);
637     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(forget_gate_bias, cell_bias, output_gate_bias);
638 
639     ARM_COMPUTE_RETURN_ERROR_ON(cell_state_in->num_dimensions() != 2);
640     ARM_COMPUTE_RETURN_ERROR_ON(cell_state_in->dimension(0) != num_units);
641     ARM_COMPUTE_RETURN_ERROR_ON(cell_state_in->dimension(1) != batch_size);
642     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(cell_state_in, 1, DataType::QSYMM16);
643 
644     ARM_COMPUTE_RETURN_ERROR_ON(output_state_in->num_dimensions() != 2);
645     ARM_COMPUTE_RETURN_ERROR_ON(output_state_in->dimension(0) != output_size);
646     ARM_COMPUTE_RETURN_ERROR_ON(output_state_in->dimension(1) != batch_size);
647     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output_state_in);
648 
649     // Check whether peephole weights are all there or none
650     if(lstm_params.has_peephole_opt())
651     {
652         ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_forget_weights(), lstm_params.cell_to_output_weights());
653         ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lstm_params.cell_to_forget_weights(), 1, DataType::QSYMM16);
654         ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_forget_weights()->num_dimensions() != 1);
655         ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_forget_weights()->dimension(0) != num_units);
656         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lstm_params.cell_to_forget_weights(), lstm_params.cell_to_output_weights());
657         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(lstm_params.cell_to_forget_weights(), lstm_params.cell_to_output_weights());
658 
659         if(!lstm_params.has_cifg_opt())
660         {
661             ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_input_weights());
662             ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lstm_params.cell_to_forget_weights(), lstm_params.cell_to_input_weights());
663             ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(lstm_params.cell_to_forget_weights(), lstm_params.cell_to_input_weights());
664         }
665     }
666 
667     const UniformQuantizationInfo qinput           = input->quantization_info().uniform();
668     const UniformQuantizationInfo qcell_state_in   = cell_state_in->quantization_info().uniform();
669     const UniformQuantizationInfo qoutput_state_in = output_state_in->quantization_info().uniform();
670 
671     // Calculate and decompose effective scales for optimizing matmul calculation
672     const int32_t cell_shift = log2(qcell_state_in.scale);
673     ARM_COMPUTE_RETURN_ERROR_ON(cell_shift > -9);
674 
675     // Calculate quantized parameters for clipping.
676     int16_t quantized_cell_clip = 0;
677     if(lstm_params.cell_clip() > 0.0f)
678     {
679         quantized_cell_clip = quantize_qsymm16(lstm_params.cell_clip(), qcell_state_in);
680     }
681 
682     // Precompute effective bias for optimizing the matmul computations.
683     const TensorInfo eff_bias_info(TensorShape(num_units), 1, DataType::S32);
684     const TensorInfo projection_eff_bias_info(TensorShape(output_size), 1, DataType::S32);
685     if(!lstm_params.has_cifg_opt())
686     {
687         ARM_COMPUTE_RETURN_ON_ERROR(ClGemmLowpMatrixAReductionKernel::validate(lstm_params.input_to_input_weights(), &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true)));
688         ARM_COMPUTE_RETURN_ON_ERROR(ClGemmLowpMatrixAReductionKernel::validate(lstm_params.recurrent_to_input_weights(), &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset,
689                                                                                true)));
690     }
691     ARM_COMPUTE_RETURN_ON_ERROR(ClGemmLowpMatrixAReductionKernel::validate(input_to_forget_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true)));
692     ARM_COMPUTE_RETURN_ON_ERROR(ClGemmLowpMatrixAReductionKernel::validate(recurrent_to_forget_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true)));
693     ARM_COMPUTE_RETURN_ON_ERROR(ClGemmLowpMatrixAReductionKernel::validate(input_to_cell_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true)));
694     ARM_COMPUTE_RETURN_ON_ERROR(ClGemmLowpMatrixAReductionKernel::validate(recurrent_to_cell_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true)));
695     ARM_COMPUTE_RETURN_ON_ERROR(ClGemmLowpMatrixAReductionKernel::validate(input_to_output_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true)));
696     ARM_COMPUTE_RETURN_ON_ERROR(ClGemmLowpMatrixAReductionKernel::validate(recurrent_to_output_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true)));
697     if(lstm_params.has_projection())
698     {
699         ARM_COMPUTE_RETURN_ON_ERROR(ClGemmLowpMatrixAReductionKernel::validate(lstm_params.projection_weights(), &projection_eff_bias_info, GEMMLowpReductionKernelInfo(output_size, false,
700                                                                                lstm_params.hidden_state_zero(),
701                                                                                true)));
702         if(lstm_params.projection_bias() != nullptr)
703         {
704             ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lstm_params.projection_bias(), 1, DataType::S32);
705             ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(lstm_params.projection_bias(), &projection_eff_bias_info,
706                                                                        &projection_eff_bias_info, ConvertPolicy::SATURATE));
707         }
708     }
709 
710     const TensorInfo input_weights_transposed(TensorShape(num_units, input_size), 1, input_to_forget_weights->data_type(), input_to_forget_weights->quantization_info());
711     const TensorInfo recurrent_weights_transposed(TensorShape(num_units, output_size), 1, recurrent_to_forget_weights->data_type(), recurrent_to_forget_weights->quantization_info());
712 
713     // Validate weights transpose
714     ARM_COMPUTE_RETURN_ON_ERROR(CLTranspose::validate(input_to_forget_weights, &input_weights_transposed));
715     ARM_COMPUTE_RETURN_ON_ERROR(CLTranspose::validate(input_to_cell_weights, &input_weights_transposed));
716     ARM_COMPUTE_RETURN_ON_ERROR(CLTranspose::validate(input_to_output_weights, &input_weights_transposed));
717     ARM_COMPUTE_RETURN_ON_ERROR(CLTranspose::validate(recurrent_to_forget_weights, &recurrent_weights_transposed));
718     ARM_COMPUTE_RETURN_ON_ERROR(CLTranspose::validate(recurrent_to_cell_weights, &recurrent_weights_transposed));
719     ARM_COMPUTE_RETURN_ON_ERROR(CLTranspose::validate(recurrent_to_output_weights, &recurrent_weights_transposed));
720     if(!lstm_params.has_cifg_opt())
721     {
722         ARM_COMPUTE_RETURN_ON_ERROR(CLTranspose::validate(lstm_params.input_to_input_weights(), &input_weights_transposed));
723         ARM_COMPUTE_RETURN_ON_ERROR(CLTranspose::validate(lstm_params.recurrent_to_input_weights(), &recurrent_weights_transposed));
724     }
725     if(lstm_params.has_projection())
726     {
727         const TensorInfo projection_weights_transposed(TensorShape(output_size, num_units), 1, lstm_params.projection_weights()->data_type(), lstm_params.projection_weights()->quantization_info());
728         ARM_COMPUTE_RETURN_ON_ERROR(CLTranspose::validate(lstm_params.projection_weights(), &projection_weights_transposed));
729     }
730 
731     GEMMLowpOutputStageInfo gemmlowp_info;
732     gemmlowp_info.type               = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
733     gemmlowp_info.gemmlowp_min_bound = std::numeric_limits<int16_t>::lowest();
734     gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int16_t>::max();
735     gemmlowp_info.output_data_type   = DataType::QSYMM16;
736 
737     const bool has_layer_norm = lstm_params.use_layer_norm();
738 
739     // Forget gate.
740     ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.forget_intermediate_scale() == 0);
741     const TensorInfo forget_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.forget_intermediate_scale(), 0));
742     const TensorInfo mm_out_info(TensorShape(num_units, batch_size), 1, DataType::S32);
743     const float      input_to_forget_scale = input_to_forget_weights->quantization_info().uniform().scale * qinput.scale / lstm_params.forget_intermediate_scale();
744     ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_forget_scale, &mm_out_info, &forget_outstage_info));
745 
746     const float recurrent_to_forget_scale = recurrent_to_forget_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.forget_intermediate_scale();
747     ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_forget_scale, &mm_out_info, &forget_outstage_info));
748 
749     ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&forget_outstage_info, &forget_outstage_info, &forget_outstage_info, ConvertPolicy::SATURATE));
750 
751     if(lstm_params.has_peephole_opt())
752     {
753         ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lstm_params.cell_to_forget_weights(), 1, DataType::QSYMM16);
754         ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(cell_state_in, lstm_params.cell_to_forget_weights(), &mm_out_info, 1.f, ConvertPolicy::SATURATE,
755                                                                         RoundingPolicy::TO_ZERO));
756         const float cell_to_forget_scale = std::pow(2, cell_shift) * lstm_params.cell_to_forget_weights()->quantization_info().uniform().scale / lstm_params.forget_intermediate_scale();
757         ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(cell_to_forget_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
758         ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpOutputStage::validate(&mm_out_info, nullptr, &forget_outstage_info, gemmlowp_info));
759         ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&forget_outstage_info, &forget_outstage_info, &forget_outstage_info, ConvertPolicy::SATURATE));
760     }
761 
762     if(has_layer_norm)
763     {
764         const ITensorInfo *w_info = lstm_params.forget_layer_norm_weights();
765         const ITensorInfo *b_info = forget_gate_bias;
766         ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(forget_outstage_info, *w_info, *b_info));
767     }
768 
769     // Output quantization info of Sigmoid and Tanh activations
770     const QuantizationInfo sigmoid_tanh_outqinfo(1.f / 32768.f, 0);
771 
772     const TensorInfo forget_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
773     ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&forget_outstage_info, &forget_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
774 
775     // Modulation gate.
776     ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_intermediate_scale() == 0);
777     const TensorInfo cell_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.cell_intermediate_scale(), 0));
778     const float      input_to_cell_scale = input_to_cell_weights->quantization_info().uniform().scale * qinput.scale / lstm_params.cell_intermediate_scale();
779     ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_cell_scale, &mm_out_info, &cell_outstage_info));
780 
781     const float recurrent_to_cell_scale = recurrent_to_cell_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.cell_intermediate_scale();
782     ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_cell_scale, &mm_out_info, &cell_outstage_info));
783 
784     ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&cell_outstage_info, &cell_outstage_info, &cell_outstage_info, ConvertPolicy::SATURATE));
785 
786     if(has_layer_norm)
787     {
788         const ITensorInfo *w_info = lstm_params.cell_layer_norm_weights();
789         const ITensorInfo *b_info = cell_bias;
790         ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(cell_outstage_info, *w_info, *b_info));
791     }
792 
793     const TensorInfo cell_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
794     ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&cell_outstage_info, &cell_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f)));
795 
796     // Input gate.
797     const TensorInfo input_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
798     if(lstm_params.has_cifg_opt())
799     {
800         ARM_COMPUTE_RETURN_ERROR_ON_MSG(lstm_params.input_gate_bias() != nullptr, "Input gate bias must not be present when CIFG is used");
801         ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticSubtraction::validate(&input_gate_info, &forget_gate_info, &forget_gate_info, ConvertPolicy::SATURATE));
802     }
803     else
804     {
805         ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.input_to_input_weights(), lstm_params.recurrent_to_input_weights(), lstm_params.input_gate_bias());
806         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_to_forget_weights, lstm_params.input_to_input_weights(), lstm_params.recurrent_to_input_weights());
807         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input_to_forget_weights, lstm_params.input_to_input_weights());
808         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(recurrent_to_forget_weights, lstm_params.recurrent_to_input_weights());
809         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(forget_gate_bias, lstm_params.input_gate_bias());
810         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(forget_gate_bias, lstm_params.input_gate_bias());
811 
812         ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_intermediate_scale() == 0);
813         const TensorInfo input_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.input_intermediate_scale(), 0));
814         const float      input_to_input_scale = lstm_params.input_to_input_weights()->quantization_info().uniform().scale * qinput.scale / lstm_params.input_intermediate_scale();
815         ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_input_scale, &mm_out_info, &input_outstage_info));
816 
817         const float recurrent_to_input_scale = lstm_params.recurrent_to_input_weights()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.input_intermediate_scale();
818         ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_input_scale, &mm_out_info, &input_outstage_info));
819 
820         ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&input_outstage_info, &input_outstage_info, &input_outstage_info, ConvertPolicy::SATURATE));
821 
822         if(lstm_params.has_peephole_opt())
823         {
824             ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(cell_state_in, lstm_params.cell_to_input_weights(), &mm_out_info, 1.f, ConvertPolicy::SATURATE,
825                                                                             RoundingPolicy::TO_ZERO));
826             const float cell_to_input_scale = std::pow(2, cell_shift) * lstm_params.cell_to_input_weights()->quantization_info().uniform().scale / lstm_params.input_intermediate_scale();
827             ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(cell_to_input_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
828             ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpOutputStage::validate(&mm_out_info, &eff_bias_info, &input_outstage_info, gemmlowp_info));
829             ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&input_outstage_info, &input_outstage_info, &input_outstage_info, ConvertPolicy::SATURATE));
830         }
831 
832         if(has_layer_norm)
833         {
834             const ITensorInfo *w_info = lstm_params.input_layer_norm_weights();
835             const ITensorInfo *b_info = lstm_params.input_gate_bias();
836             ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(cell_outstage_info, *w_info, *b_info));
837         }
838 
839         ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&input_outstage_info, &input_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC, 1.f, 1.f)));
840     }
841     // Cell.
842     ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&forget_gate_info, cell_state_in, &forget_gate_info, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
843     ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&input_gate_info, cell_state_in, &cell_gate_info, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
844     ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&forget_gate_info, &cell_gate_info, cell_state_out, ConvertPolicy::SATURATE));
845     if(quantized_cell_clip > 0)
846     {
847         ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(cell_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -quantized_cell_clip,
848                                                                                                              quantized_cell_clip)));
849     }
850     // Output gate.
851     ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.output_intermediate_scale() == 0);
852     const TensorInfo output_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.output_intermediate_scale(), 0));
853     const float      input_to_output_scale = input_to_output_weights->quantization_info().uniform().scale * qinput.scale / lstm_params.output_intermediate_scale();
854     ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_output_scale, &mm_out_info, &output_outstage_info));
855 
856     const float recurrent_to_output_scale = recurrent_to_output_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.output_intermediate_scale();
857     ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_output_scale, &mm_out_info, &output_outstage_info));
858 
859     ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&output_outstage_info, &output_outstage_info, &output_outstage_info, ConvertPolicy::SATURATE));
860     if(lstm_params.has_peephole_opt())
861     {
862         ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lstm_params.cell_to_output_weights(), 1, DataType::QSYMM16);
863         // TODO(COMPMID-3395): Perform multiplication in the quantized domain in NEPixelWiseMultiplicationKernel
864         // Here we are not using the output stage because all operations are done in float
865         // const float cell_to_output_scale = std::pow(2, cell_shift) * lstm_params.cell_to_output_weights()->quantization_info().uniform().scale / lstm_params.output_intermediate_scale();
866         // ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(cell_to_output_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
867         ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(cell_state_out, lstm_params.cell_to_output_weights(), &output_outstage_info, 1.f, ConvertPolicy::SATURATE,
868                                                                         RoundingPolicy::TO_ZERO));
869         ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&output_outstage_info, &output_outstage_info, &output_outstage_info, ConvertPolicy::SATURATE));
870     }
871 
872     if(has_layer_norm)
873     {
874         const ITensorInfo *w_info = lstm_params.output_layer_norm_weights();
875         const ITensorInfo *b_info = output_gate_bias;
876         ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(output_outstage_info, *w_info, *b_info));
877     }
878 
879     const TensorInfo output_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
880     ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&output_outstage_info, &output_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
881 
882     // Hidden.
883     ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(cell_state_out, &input_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f)));
884     const TensorInfo hidden_mul_res(TensorShape(num_units, batch_size), 1, DataType::S32);
885     const TensorInfo hidden_out_info(TensorShape(num_units, batch_size), 1, DataType::QASYMM8_SIGNED);
886 
887     ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.hidden_state_scale() == 0);
888     ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&output_gate_info, &input_gate_info, &hidden_mul_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
889     const float hidden_state_scale = std::pow(2, -15) / lstm_params.hidden_state_scale() * std::pow(2, -15);
890     ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(hidden_state_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift, /* ignore_epsilon */ true));
891     gemmlowp_info.gemmlowp_offset  = lstm_params.hidden_state_zero();
892     gemmlowp_info.output_data_type = hidden_out_info.data_type();
893     ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpOutputStage::validate(&hidden_mul_res, nullptr, &hidden_out_info, gemmlowp_info));
894 
895     const bool projection_tensor_copy_required = num_units != output_size;
896 
897     // Projection.
898     if(lstm_params.has_projection())
899     {
900         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(recurrent_to_forget_weights, lstm_params.projection_weights());
901         ARM_COMPUTE_RETURN_ERROR_ON(qoutput_state_in.scale == 0);
902 
903         const UniformQuantizationInfo qprojection      = lstm_params.projection_weights()->quantization_info().uniform();
904         const float                   projection_scale = qprojection.scale * lstm_params.hidden_state_scale() / qoutput_state_in.scale;
905         ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(projection_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
906         gemmlowp_info.gemmlowp_offset    = qoutput_state_in.offset;
907         gemmlowp_info.gemmlowp_min_bound = std::numeric_limits<int8_t>::lowest();
908         gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int8_t>::max();
909         gemmlowp_info.output_data_type   = DataType::QASYMM8_SIGNED;
910 
911         const TensorInfo projection_outstage_info(*output_state_out);
912         const TensorInfo projection_weights_transposed(TensorShape(output_size, num_units), 1, lstm_params.projection_weights()->data_type(), lstm_params.projection_weights()->quantization_info());
913 
914         TensorInfo projection_mm_out_info{ mm_out_info };
915         projection_mm_out_info.set_tensor_shape(TensorShape(output_size, batch_size));
916 
917         ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, &hidden_out_info, &projection_weights_transposed, &projection_eff_bias_info, projection_scale, &projection_mm_out_info,
918                                                 &projection_outstage_info));
919 
920         if(projection_tensor_copy_required)
921         {
922             ARM_COMPUTE_RETURN_ON_ERROR(CLQLSTMLayer::TensorCopyKernel::validate(*output_state_in, projection_outstage_info));
923         }
924 
925         ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(output_state_out, output_state_out, output_state_out, ConvertPolicy::SATURATE));
926 
927         if(projection_tensor_copy_required)
928         {
929             ARM_COMPUTE_RETURN_ON_ERROR(CLQLSTMLayer::TensorCopyKernel::validate(projection_outstage_info, *output_state_out));
930         }
931 
932         int8_t quantized_projection_clip{ 0 };
933         if(lstm_params.projection_clip() > 0.0f)
934         {
935             quantized_projection_clip = quantize_qasymm8_signed(lstm_params.projection_clip(), qprojection);
936         }
937 
938         if(quantized_projection_clip > 0)
939         {
940             ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(output_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -quantized_projection_clip,
941                                                                                                                    quantized_projection_clip)));
942         }
943     }
944     else
945     {
946         if(projection_tensor_copy_required)
947         {
948             ARM_COMPUTE_RETURN_ON_ERROR(CLQLSTMLayer::TensorCopyKernel::validate(hidden_out_info, *output_state_out));
949         }
950     }
951 
952     if(cell_state_out->total_size() > 0)
953     {
954         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(cell_state_in, cell_state_out);
955         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(cell_state_in, cell_state_out);
956     }
957 
958     if(output_state_out->total_size() > 0)
959     {
960         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output_state_out);
961         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output_state_in, output_state_out);
962     }
963 
964     ARM_COMPUTE_RETURN_ON_ERROR(CLCopy::validate(output_state_out, output));
965     return Status{};
966 }
967 
run()968 void CLQLSTMLayer::run()
969 {
970     prepare();
971 
972     // Acquire all the temporaries
973     MemoryGroupResourceScope scope_mg(_memory_group);
974 
975     // Forget gate.
976     _mm_input_to_forget.run();
977     _input_to_forget_outstage.run();
978 
979     _mm_recurrent_to_forget.run();
980     _recurrent_to_forget_outstage.run();
981     _accumulate_input_recurrent_forget.run();
982 
983     if(_has_peephole)
984     {
985         _pixelwise_mul_cell_to_forget.run();
986         _cell_to_forget_outstage.run();
987         _accumulate_cell_forget.run();
988     }
989 
990     if(_has_layer_norm)
991     {
992         CLScheduler::get().enqueue(get_layer_norm(LayerNormGate::Forget));
993     }
994 
995     _forget_gate_sigmoid.run();
996 
997     // Modulation gate.
998     _mm_input_to_cell.run();
999     _input_to_cell_outstage.run();
1000 
1001     _mm_recurrent_to_cell.run();
1002     _recurrent_to_cell_outstage.run();
1003     _accumulate_input_recurrent_modulation.run();
1004 
1005     if(_has_layer_norm)
1006     {
1007         CLScheduler::get().enqueue(get_layer_norm(LayerNormGate::Cell));
1008     }
1009 
1010     _cell_gate_tanh.run();
1011 
1012     // Input gate
1013     if(_has_cifg)
1014     {
1015         _input_gate_sub.run();
1016     }
1017     else
1018     {
1019         _mm_input_to_input.run();
1020         _input_to_input_outstage.run();
1021         _mm_recurrent_to_input.run();
1022         _recurrent_to_input_outstage.run();
1023         _accumulate_input_recurrent_input.run();
1024 
1025         if(_has_peephole)
1026         {
1027             _pixelwise_mul_cell_to_input.run();
1028             _cell_to_input_outstage.run();
1029             _accumulate_cell_input.run();
1030         }
1031 
1032         if(_has_layer_norm)
1033         {
1034             CLScheduler::get().enqueue(get_layer_norm(LayerNormGate::Input));
1035         }
1036 
1037         _input_gate_sigmoid.run();
1038     }
1039 
1040     // Cell.
1041     _pixelwise_mul_forget_cell.run();
1042     _pixelwise_mul_input_cell.run();
1043     _add_forget_cell.run();
1044     if(_has_cell_clipping)
1045     {
1046         _cell_clip.run();
1047     }
1048 
1049     // Output gate.
1050     _mm_input_to_output.run();
1051     _input_to_output_outstage.run();
1052     _mm_recurrent_to_output.run();
1053     _recurrent_to_output_outstage.run();
1054     _accumulate_input_recurrent_output.run();
1055     if(_has_peephole)
1056     {
1057         _pixelwise_mul_cell_to_output.run();
1058         _cell_to_output_outstage.run();
1059         _accumulate_cell_to_output.run();
1060     }
1061 
1062     if(_has_layer_norm)
1063     {
1064         CLScheduler::get().enqueue(get_layer_norm(LayerNormGate::Output));
1065     }
1066 
1067     _output_gate_sigmoid.run();
1068 
1069     // Hidden.
1070     _hidden_tanh.run();
1071     _pixelwise_mul_hidden.run();
1072     _hidden_outstage.run();
1073 
1074     // Projection.
1075     if(_has_projection)
1076     {
1077         _mm_projection.run();
1078         _projection_outstage.run();
1079 
1080         if(_projection_tensor_copy_required)
1081         {
1082             _projection_output_to_accumulate_copy.run();
1083         }
1084 
1085         _accumulate_projection.run();
1086 
1087         if(_projection_tensor_copy_required)
1088         {
1089             _projection_accumulate_to_output_copy.run();
1090         }
1091 
1092         if(_has_projection_clipping)
1093         {
1094             _projection_clip.run();
1095         }
1096     }
1097     else
1098     {
1099         if(_projection_tensor_copy_required)
1100         {
1101             _hidden_to_output_copy.run();
1102         }
1103     }
1104 
1105     // Copy output_state_out to output
1106     _copy_output.run();
1107 }
1108 
prepare()1109 void CLQLSTMLayer::prepare()
1110 {
1111     if(!_is_prepared)
1112     {
1113         // Pre-transpose weights to be used in GEMM.
1114         _input_to_forget_weights_transposed.allocator()->allocate();
1115         _input_to_cell_weights_transposed.allocator()->allocate();
1116         _input_to_output_weights_transposed.allocator()->allocate();
1117         _recurrent_to_forget_weights_transposed.allocator()->allocate();
1118         _recurrent_to_cell_weights_transposed.allocator()->allocate();
1119         _recurrent_to_output_weights_transposed.allocator()->allocate();
1120         _transpose_input_to_forget_weights.run();
1121         _transpose_input_to_cell_weights.run();
1122         _transpose_input_to_output_weights.run();
1123         _transpose_recurrent_to_forget_weights.run();
1124         _transpose_recurrent_to_cell_weights.run();
1125         _transpose_recurrent_to_output_weights.run();
1126 
1127         // Precompute effective biases
1128         if(_has_cifg)
1129         {
1130             _ones.map(true);
1131             std::fill_n(reinterpret_cast<int16_t *>(_ones.buffer()), _ones.info()->total_size() / _ones.info()->element_size(), 32767);
1132             _ones.unmap();
1133         }
1134         else
1135         {
1136             _input_to_input_eff_bias.allocator()->allocate();
1137             _recurrent_to_input_eff_bias.allocator()->allocate();
1138 
1139             ITensorPack input_to_input_red_pack = { { ACL_SRC, _input_to_input_weights }, { ACL_DST, &_input_to_input_eff_bias } };
1140             CLScheduler::get().enqueue_op(*_input_to_input_reduction, input_to_input_red_pack, false);
1141 
1142             ITensorPack rec_to_input_red_pack = { { ACL_SRC, _recurrent_to_input_weights }, { ACL_DST, &_recurrent_to_input_eff_bias } };
1143             CLScheduler::get().enqueue_op(*_recurrent_to_input_reduction, rec_to_input_red_pack, false);
1144 
1145             _input_to_input_weights_transposed.allocator()->allocate();
1146             _recurrent_to_input_weights_transposed.allocator()->allocate();
1147             _transpose_input_to_input_weights.run();
1148             _transpose_recurrent_to_input_weights.run();
1149             _input_to_input_weights->mark_as_unused();
1150             _recurrent_to_input_weights->mark_as_unused();
1151         }
1152         _input_to_forget_eff_bias.allocator()->allocate();
1153         _recurrent_to_forget_eff_bias.allocator()->allocate();
1154         _input_to_cell_eff_bias.allocator()->allocate();
1155         _recurrent_to_cell_eff_bias.allocator()->allocate();
1156         _input_to_output_eff_bias.allocator()->allocate();
1157         _recurrent_to_output_eff_bias.allocator()->allocate();
1158 
1159         ITensorPack input_to_forget_red_pack = { { ACL_SRC, _input_to_forget_weights }, { ACL_DST, &_input_to_forget_eff_bias } };
1160         CLScheduler::get().enqueue_op(*_input_to_forget_reduction, input_to_forget_red_pack, false);
1161 
1162         ITensorPack rec_to_forget_red_pack = { { ACL_SRC, _recurrent_to_forget_weights }, { ACL_DST, &_recurrent_to_forget_eff_bias } };
1163         CLScheduler::get().enqueue_op(*_recurrent_to_forget_reduction, rec_to_forget_red_pack, false);
1164 
1165         ITensorPack input_to_cell_red_pack = { { ACL_SRC, _input_to_cell_weights }, { ACL_DST, &_input_to_cell_eff_bias } };
1166         CLScheduler::get().enqueue_op(*_input_to_cell_reduction, input_to_cell_red_pack, false);
1167 
1168         ITensorPack rec_to_cell_red_pack = { { ACL_SRC, _recurrent_to_cell_weights }, { ACL_DST, &_recurrent_to_cell_eff_bias } };
1169         CLScheduler::get().enqueue_op(*_recurrent_to_cell_reduction, rec_to_cell_red_pack, false);
1170 
1171         ITensorPack input_to_output_red_pack = { { ACL_SRC, _input_to_output_weights }, { ACL_DST, &_input_to_output_eff_bias } };
1172         CLScheduler::get().enqueue_op(*_input_to_output_reduction, input_to_output_red_pack, false);
1173 
1174         ITensorPack rec_to_output_red_pack = { { ACL_SRC, _recurrent_to_output_weights }, { ACL_DST, &_recurrent_to_output_eff_bias } };
1175         CLScheduler::get().enqueue_op(*_recurrent_to_output_reduction, rec_to_output_red_pack, false);
1176 
1177         if(_has_projection)
1178         {
1179             _projection_eff_bias.allocator()->allocate();
1180             ITensorPack proj_red_pack{ { ACL_SRC, _projection_weights }, { ACL_DST, &_projection_eff_bias } };
1181             CLScheduler::get().enqueue_op(*_projection_reduction, proj_red_pack, false);
1182             if(_projection_bias != nullptr)
1183             {
1184                 _projection_bias_add.run();
1185                 _projection_bias->mark_as_unused();
1186             }
1187 
1188             _projection_weights_transposed.allocator()->allocate();
1189             _transpose_projection_weights.run();
1190             _projection_weights->mark_as_unused();
1191 
1192             if(!_projection_tensor_copy_required)
1193             {
1194                 _hidden_gate.mark_as_unused();
1195                 _projection_accumulate_res.mark_as_unused();
1196             }
1197         }
1198 
1199         // Mark weights as unused
1200         _input_to_forget_weights->mark_as_unused();
1201         _input_to_cell_weights->mark_as_unused();
1202         _input_to_output_weights->mark_as_unused();
1203         _recurrent_to_forget_weights->mark_as_unused();
1204         _recurrent_to_cell_weights->mark_as_unused();
1205         _recurrent_to_output_weights->mark_as_unused();
1206 
1207         CLScheduler::get().queue().finish();
1208         _is_prepared = true;
1209     }
1210 }
1211 
1212 } // namespace arm_compute
1213