1*c217d954SCole Faust /*
2*c217d954SCole Faust * Copyright (c) 2020-2022 Arm Limited.
3*c217d954SCole Faust *
4*c217d954SCole Faust * SPDX-License-Identifier: MIT
5*c217d954SCole Faust *
6*c217d954SCole Faust * Permission is hereby granted, free of charge, to any person obtaining a copy
7*c217d954SCole Faust * of this software and associated documentation files (the "Software"), to
8*c217d954SCole Faust * deal in the Software without restriction, including without limitation the
9*c217d954SCole Faust * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10*c217d954SCole Faust * sell copies of the Software, and to permit persons to whom the Software is
11*c217d954SCole Faust * furnished to do so, subject to the following conditions:
12*c217d954SCole Faust *
13*c217d954SCole Faust * The above copyright notice and this permission notice shall be included in all
14*c217d954SCole Faust * copies or substantial portions of the Software.
15*c217d954SCole Faust *
16*c217d954SCole Faust * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17*c217d954SCole Faust * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18*c217d954SCole Faust * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19*c217d954SCole Faust * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20*c217d954SCole Faust * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21*c217d954SCole Faust * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22*c217d954SCole Faust * SOFTWARE.
23*c217d954SCole Faust */
24*c217d954SCole Faust #include "arm_compute/runtime/NEON/functions/NEQLSTMLayer.h"
25*c217d954SCole Faust
26*c217d954SCole Faust #include "arm_compute/core/ITensorPack.h"
27*c217d954SCole Faust #include "arm_compute/core/KernelDescriptors.h"
28*c217d954SCole Faust #include "arm_compute/core/QuantizationInfo.h"
29*c217d954SCole Faust #include "arm_compute/core/Utils.h"
30*c217d954SCole Faust #include "arm_compute/core/Validate.h"
31*c217d954SCole Faust #include "arm_compute/core/utils/misc/InfoHelpers.h"
32*c217d954SCole Faust #include "arm_compute/core/utils/quantization/AsymmHelpers.h"
33*c217d954SCole Faust #include "arm_compute/runtime/NEON/NEScheduler.h"
34*c217d954SCole Faust #include "src/common/utils/Log.h"
35*c217d954SCole Faust #include "src/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.h"
36*c217d954SCole Faust #include "src/core/helpers/WindowHelpers.h"
37*c217d954SCole Faust #include "src/cpu/kernels/CpuGemmLowpMatrixReductionKernel.h"
38*c217d954SCole Faust
39*c217d954SCole Faust namespace arm_compute
40*c217d954SCole Faust {
41*c217d954SCole Faust using namespace arm_compute::utils::info_helpers;
42*c217d954SCole Faust namespace
43*c217d954SCole Faust {
validate_mm(GEMMLowpOutputStageInfo & gemmlowp_info,const ITensorInfo * mm_input,const ITensorInfo * mm_weights,const ITensorInfo * bias,float gemmlowp_scale,const TensorInfo * mm_res_info,const TensorInfo * outstage_tensor_info)44*c217d954SCole Faust Status validate_mm(GEMMLowpOutputStageInfo &gemmlowp_info, const ITensorInfo *mm_input, const ITensorInfo *mm_weights, const ITensorInfo *bias,
45*c217d954SCole Faust float gemmlowp_scale, const TensorInfo *mm_res_info, const TensorInfo *outstage_tensor_info)
46*c217d954SCole Faust {
47*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpMatrixMultiplyCore::validate(mm_input, mm_weights, nullptr, mm_res_info));
48*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(gemmlowp_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
49*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpOutputStage::validate(mm_res_info, bias, outstage_tensor_info, gemmlowp_info));
50*c217d954SCole Faust return Status{};
51*c217d954SCole Faust }
52*c217d954SCole Faust } // namespace
53*c217d954SCole Faust
validate_layer_norm(const ITensorInfo & in,const ITensorInfo & weight,const ITensorInfo & bias)54*c217d954SCole Faust Status NEQLSTMLayer::validate_layer_norm(const ITensorInfo &in, const ITensorInfo &weight, const ITensorInfo &bias)
55*c217d954SCole Faust {
56*c217d954SCole Faust // Output quantization scale will be different, but ignored here
57*c217d954SCole Faust // since it will be configured at configure() stage.
58*c217d954SCole Faust const TensorInfo out
59*c217d954SCole Faust {
60*c217d954SCole Faust in
61*c217d954SCole Faust };
62*c217d954SCole Faust return NEQLSTMLayerNormalizationKernel::validate(&in, &out, &weight, &bias);
63*c217d954SCole Faust }
64*c217d954SCole Faust
configure_layer_norm(NEQLSTMLayer::LayerNormGate g,const ITensor * in)65*c217d954SCole Faust void NEQLSTMLayer::configure_layer_norm(NEQLSTMLayer::LayerNormGate g, const ITensor *in)
66*c217d954SCole Faust {
67*c217d954SCole Faust ARM_COMPUTE_ERROR_ON(!_has_layer_norm);
68*c217d954SCole Faust
69*c217d954SCole Faust Tensor &out = get_layer_norm_output(g);
70*c217d954SCole Faust _memory_group.manage(&out);
71*c217d954SCole Faust out.allocator()->init(*(in->info()));
72*c217d954SCole Faust
73*c217d954SCole Faust get_layer_norm(g) = std::make_unique<NEQLSTMLayerNormalizationKernel>();
74*c217d954SCole Faust get_layer_norm(g)->configure(in, &out, get_layer_norm_weight(g), get_layer_norm_bias(g));
75*c217d954SCole Faust }
76*c217d954SCole Faust
// Out-of-line, compiler-generated destructor for the tensor copy helper.
77*c217d954SCole Faust NEQLSTMLayer::TensorCopyKernel::~TensorCopyKernel() = default;
78*c217d954SCole Faust
validate(const ITensorInfo & src,const ITensorInfo & dst)79*c217d954SCole Faust Status NEQLSTMLayer::TensorCopyKernel::validate(const ITensorInfo &src, const ITensorInfo &dst)
80*c217d954SCole Faust {
81*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON(src.tensor_shape().num_dimensions() > max_dimension_supported);
82*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON(dst.tensor_shape().num_dimensions() > max_dimension_supported);
83*c217d954SCole Faust ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(&src, &dst);
84*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON(dst.tensor_shape().y() != src.tensor_shape().y());
85*c217d954SCole Faust return Status{};
86*c217d954SCole Faust }
87*c217d954SCole Faust
configure(ITensor & src,ITensor & dst)88*c217d954SCole Faust void NEQLSTMLayer::TensorCopyKernel::configure(ITensor &src, ITensor &dst)
89*c217d954SCole Faust {
90*c217d954SCole Faust ARM_COMPUTE_ERROR_THROW_ON(NEQLSTMLayer::TensorCopyKernel::validate(*src.info(), *dst.info()));
91*c217d954SCole Faust ARM_COMPUTE_LOG_PARAMS(src, dst);
92*c217d954SCole Faust
93*c217d954SCole Faust _src = &src;
94*c217d954SCole Faust _dst = &dst;
95*c217d954SCole Faust _row_size = std::min(_src->info()->tensor_shape().x(), _dst->info()->tensor_shape().x());
96*c217d954SCole Faust _window = calculate_max_window(*_src->info(), Steps());
97*c217d954SCole Faust }
98*c217d954SCole Faust
run()99*c217d954SCole Faust void NEQLSTMLayer::TensorCopyKernel::run()
100*c217d954SCole Faust {
101*c217d954SCole Faust Iterator input_iter{ _src, _window };
102*c217d954SCole Faust Iterator output_iter{ _dst, _window };
103*c217d954SCole Faust
104*c217d954SCole Faust execute_window_loop(_window, [&](const Coordinates &)
105*c217d954SCole Faust {
106*c217d954SCole Faust memcpy(output_iter.ptr(), input_iter.ptr(), _row_size);
107*c217d954SCole Faust },
108*c217d954SCole Faust input_iter, output_iter);
109*c217d954SCole Faust }
110*c217d954SCole Faust
// Out-of-line, compiler-generated destructor.
// NOTE(review): presumably defined in this translation unit so that members holding
// kernel types (e.g. std::unique_ptr to kernels included above) are destroyed where
// those types are complete - confirm against the header.
111*c217d954SCole Faust NEQLSTMLayer::~NEQLSTMLayer() = default;
112*c217d954SCole Faust
NEQLSTMLayer(std::shared_ptr<IMemoryManager> memory_manager)113*c217d954SCole Faust NEQLSTMLayer::NEQLSTMLayer(std::shared_ptr<IMemoryManager> memory_manager)
114*c217d954SCole Faust : _memory_group(),
115*c217d954SCole Faust _dequantize_input_to_forget_weights(),
116*c217d954SCole Faust _quantize_input_to_forget_weights(),
117*c217d954SCole Faust _transpose_input_to_forget_weights(),
118*c217d954SCole Faust _transpose_input_to_cell_weights(),
119*c217d954SCole Faust _transpose_input_to_output_weights(),
120*c217d954SCole Faust _transpose_input_to_input_weights(),
121*c217d954SCole Faust _transpose_recurrent_to_forget_weights(),
122*c217d954SCole Faust _transpose_recurrent_to_cell_weights(),
123*c217d954SCole Faust _transpose_recurrent_to_output_weights(),
124*c217d954SCole Faust _transpose_recurrent_to_input_weights(),
125*c217d954SCole Faust _transpose_projection_weights(),
126*c217d954SCole Faust _input_to_input_reduction(),
127*c217d954SCole Faust _recurrent_to_input_reduction(),
128*c217d954SCole Faust _input_to_forget_reduction(),
129*c217d954SCole Faust _recurrent_to_forget_reduction(),
130*c217d954SCole Faust _input_to_cell_reduction(),
131*c217d954SCole Faust _recurrent_to_cell_reduction(),
132*c217d954SCole Faust _input_to_output_reduction(),
133*c217d954SCole Faust _recurrent_to_output_reduction(),
134*c217d954SCole Faust _projection_reduction(),
135*c217d954SCole Faust _projection_bias_add(),
136*c217d954SCole Faust _mm_input_to_forget(),
137*c217d954SCole Faust _mm_recurrent_to_forget(),
138*c217d954SCole Faust _pixelwise_mul_cell_to_forget(),
139*c217d954SCole Faust _input_to_forget_outstage(),
140*c217d954SCole Faust _recurrent_to_forget_outstage(),
141*c217d954SCole Faust _cell_to_forget_outstage(),
142*c217d954SCole Faust _accumulate_input_recurrent_forget(),
143*c217d954SCole Faust _accumulate_cell_forget(),
144*c217d954SCole Faust _forget_gate_sigmoid(),
145*c217d954SCole Faust _mm_input_to_cell(),
146*c217d954SCole Faust _input_to_cell_outstage(),
147*c217d954SCole Faust _mm_recurrent_to_cell(),
148*c217d954SCole Faust _recurrent_to_cell_outstage(),
149*c217d954SCole Faust _accumulate_input_recurrent_modulation(),
150*c217d954SCole Faust _cell_gate_tanh(),
151*c217d954SCole Faust _input_gate_sub(),
152*c217d954SCole Faust _mm_input_to_input(),
153*c217d954SCole Faust _input_to_input_outstage(),
154*c217d954SCole Faust _mm_recurrent_to_input(),
155*c217d954SCole Faust _recurrent_to_input_outstage(),
156*c217d954SCole Faust _accumulate_input_recurrent_input(),
157*c217d954SCole Faust _pixelwise_mul_cell_to_input(),
158*c217d954SCole Faust _cell_to_input_outstage(),
159*c217d954SCole Faust _accumulate_cell_input(),
160*c217d954SCole Faust _input_gate_sigmoid(),
161*c217d954SCole Faust _pixelwise_mul_forget_cell(),
162*c217d954SCole Faust _pixelwise_mul_input_cell(),
163*c217d954SCole Faust _add_forget_cell(),
164*c217d954SCole Faust _cell_clip(),
165*c217d954SCole Faust _mm_input_to_output(),
166*c217d954SCole Faust _input_to_output_outstage(),
167*c217d954SCole Faust _mm_recurrent_to_output(),
168*c217d954SCole Faust _recurrent_to_output_outstage(),
169*c217d954SCole Faust _accumulate_input_recurrent_output(),
170*c217d954SCole Faust _pixelwise_mul_cell_to_output(),
171*c217d954SCole Faust _cell_to_output_outstage(),
172*c217d954SCole Faust _accumulate_cell_to_output(),
173*c217d954SCole Faust _output_gate_sigmoid(),
174*c217d954SCole Faust _hidden_tanh(),
175*c217d954SCole Faust _pixelwise_mul_hidden(),
176*c217d954SCole Faust _hidden_outstage(),
177*c217d954SCole Faust _mm_projection(),
178*c217d954SCole Faust _projection_outstage(),
179*c217d954SCole Faust _accumulate_projection(),
180*c217d954SCole Faust _projection_clip(),
181*c217d954SCole Faust _projection_bias_copy(),
182*c217d954SCole Faust _projection_output_to_accumulate_copy(),
183*c217d954SCole Faust _projection_accumulate_to_output_copy(),
184*c217d954SCole Faust _hidden_to_output_copy(),
185*c217d954SCole Faust _layer_norms(),
186*c217d954SCole Faust _copy_output(),
187*c217d954SCole Faust _layer_norm_weights(),
188*c217d954SCole Faust _layer_norm_bias(),
189*c217d954SCole Faust _layer_norm_output()
190*c217d954SCole Faust {
191*c217d954SCole Faust _memory_group = MemoryGroup(std::move(memory_manager));
192*c217d954SCole Faust }
193*c217d954SCole Faust
configure_mm(NEGEMMLowpMatrixMultiplyCore & mm,NEGEMMLowpOutputStage & outstage,GEMMLowpOutputStageInfo & gemmlowp_info,const ITensor * mm_input,const ITensor * mm_weights,const ITensor * bias,Tensor * mm_res,Tensor * outstage_res,float gemmlowp_scale,const TensorInfo & mm_res_info,const TensorInfo & outstage_tensor_info)194*c217d954SCole Faust void NEQLSTMLayer::configure_mm(NEGEMMLowpMatrixMultiplyCore &mm, NEGEMMLowpOutputStage &outstage, GEMMLowpOutputStageInfo &gemmlowp_info,
195*c217d954SCole Faust const ITensor *mm_input, const ITensor *mm_weights, const ITensor *bias,
196*c217d954SCole Faust Tensor *mm_res, Tensor *outstage_res, float gemmlowp_scale,
197*c217d954SCole Faust const TensorInfo &mm_res_info, const TensorInfo &outstage_tensor_info)
198*c217d954SCole Faust {
199*c217d954SCole Faust _memory_group.manage(mm_res);
200*c217d954SCole Faust _memory_group.manage(outstage_res);
201*c217d954SCole Faust
202*c217d954SCole Faust mm_res->allocator()->init(mm_res_info);
203*c217d954SCole Faust outstage_res->allocator()->init(outstage_tensor_info);
204*c217d954SCole Faust
205*c217d954SCole Faust // Configure matrix-multiplication
206*c217d954SCole Faust mm.configure(mm_input, mm_weights, nullptr, mm_res);
207*c217d954SCole Faust
208*c217d954SCole Faust // Configure output stage
209*c217d954SCole Faust quantization::calculate_quantized_multiplier(gemmlowp_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift);
210*c217d954SCole Faust outstage.configure(mm_res, bias, outstage_res, gemmlowp_info);
211*c217d954SCole Faust mm_res->allocator()->allocate();
212*c217d954SCole Faust }
213*c217d954SCole Faust
configure(const ITensor * input,const ITensor * input_to_forget_weights,const ITensor * input_to_cell_weights,const ITensor * input_to_output_weights,const ITensor * recurrent_to_forget_weights,const ITensor * recurrent_to_cell_weights,const ITensor * recurrent_to_output_weights,const ITensor * forget_gate_bias,const ITensor * cell_bias,const ITensor * output_gate_bias,const ITensor * cell_state_in,ITensor * output_state_in,ITensor * cell_state_out,ITensor * output_state_out,ITensor * output,const LSTMParams<ITensor> & lstm_params)214*c217d954SCole Faust void NEQLSTMLayer::configure(const ITensor *input,
215*c217d954SCole Faust const ITensor *input_to_forget_weights, const ITensor *input_to_cell_weights, const ITensor *input_to_output_weights,
216*c217d954SCole Faust const ITensor *recurrent_to_forget_weights, const ITensor *recurrent_to_cell_weights, const ITensor *recurrent_to_output_weights,
217*c217d954SCole Faust const ITensor *forget_gate_bias, const ITensor *cell_bias, const ITensor *output_gate_bias,
218*c217d954SCole Faust const ITensor *cell_state_in, ITensor *output_state_in,
219*c217d954SCole Faust ITensor *cell_state_out, ITensor *output_state_out, ITensor *output,
220*c217d954SCole Faust const LSTMParams<ITensor> &lstm_params)
221*c217d954SCole Faust {
222*c217d954SCole Faust ARM_COMPUTE_ERROR_ON_NULLPTR(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
223*c217d954SCole Faust recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
224*c217d954SCole Faust forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in, cell_state_out, output_state_out);
225*c217d954SCole Faust
226*c217d954SCole Faust ARM_COMPUTE_LOG_PARAMS(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
227*c217d954SCole Faust recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
228*c217d954SCole Faust forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in, cell_state_out, output_state_out);
229*c217d954SCole Faust
230*c217d954SCole Faust // Set lstm parameters
231*c217d954SCole Faust LSTMParams<ITensorInfo> lstm_params_info{};
232*c217d954SCole Faust build_lstm_params_tensor_info(lstm_params, &lstm_params_info);
233*c217d954SCole Faust
234*c217d954SCole Faust _input_to_forget_weights_transposed.info()->set_quantization_info(input_to_forget_weights->info()->quantization_info());
235*c217d954SCole Faust _input_to_cell_weights_transposed.info()->set_quantization_info(input_to_cell_weights->info()->quantization_info());
236*c217d954SCole Faust _input_to_output_weights_transposed.info()->set_quantization_info(input_to_output_weights->info()->quantization_info());
237*c217d954SCole Faust _recurrent_to_forget_weights_transposed.info()->set_quantization_info(recurrent_to_forget_weights->info()->quantization_info());
238*c217d954SCole Faust _recurrent_to_cell_weights_transposed.info()->set_quantization_info(recurrent_to_cell_weights->info()->quantization_info());
239*c217d954SCole Faust _recurrent_to_output_weights_transposed.info()->set_quantization_info(recurrent_to_output_weights->info()->quantization_info());
240*c217d954SCole Faust
241*c217d954SCole Faust if(input_to_forget_weights->info()->data_type() == DataType::QASYMM8_SIGNED)
242*c217d954SCole Faust {
243*c217d954SCole Faust _convert_input_to_forget_weights_to_qsymm8 = true;
244*c217d954SCole Faust // Setup dequantize output tensor to go from QASYMM8_SIGNED -> F32
245*c217d954SCole Faust
246*c217d954SCole Faust _input_to_forget_weights_f32.allocator()->init(TensorInfo(input_to_forget_weights->info()->tensor_shape(), 1, DataType::F32)
247*c217d954SCole Faust .set_data_layout(input_to_forget_weights->info()->data_layout()));
248*c217d954SCole Faust // Setup the quantize output tensor to go from F32 -> QSYMM8
249*c217d954SCole Faust _input_to_forget_weights_symm8.allocator()->init((TensorInfo(input_to_forget_weights->info()->tensor_shape(), 1, DataType::QSYMM8)
250*c217d954SCole Faust .set_data_layout(input_to_forget_weights->info()->data_layout())
251*c217d954SCole Faust .set_quantization_info(input_to_forget_weights->info()->quantization_info())));
252*c217d954SCole Faust
253*c217d954SCole Faust _dequantize_input_to_forget_weights.configure(input_to_forget_weights, &_input_to_forget_weights_f32);
254*c217d954SCole Faust _quantize_input_to_forget_weights.configure(&_input_to_forget_weights_f32, &_input_to_forget_weights_symm8);
255*c217d954SCole Faust
256*c217d954SCole Faust ARM_COMPUTE_ERROR_THROW_ON(NEQLSTMLayer::validate(input->info(), _input_to_forget_weights_symm8.info(), input_to_cell_weights->info(), input_to_output_weights->info(),
257*c217d954SCole Faust recurrent_to_forget_weights->info(), recurrent_to_cell_weights->info(), recurrent_to_output_weights->info(),
258*c217d954SCole Faust forget_gate_bias->info(), cell_bias->info(), output_gate_bias->info(),
259*c217d954SCole Faust cell_state_in->info(), output_state_in->info(), cell_state_out->info(), output_state_out->info(), output->info(),
260*c217d954SCole Faust lstm_params_info));
261*c217d954SCole Faust }
262*c217d954SCole Faust else
263*c217d954SCole Faust {
264*c217d954SCole Faust ARM_COMPUTE_ERROR_THROW_ON(NEQLSTMLayer::validate(input->info(), input_to_forget_weights->info(), input_to_cell_weights->info(), input_to_output_weights->info(),
265*c217d954SCole Faust recurrent_to_forget_weights->info(), recurrent_to_cell_weights->info(), recurrent_to_output_weights->info(),
266*c217d954SCole Faust forget_gate_bias->info(), cell_bias->info(), output_gate_bias->info(),
267*c217d954SCole Faust cell_state_in->info(), output_state_in->info(), cell_state_out->info(), output_state_out->info(), output->info(),
268*c217d954SCole Faust lstm_params_info));
269*c217d954SCole Faust }
270*c217d954SCole Faust
271*c217d954SCole Faust const int batch_size = input->info()->dimension(1);
272*c217d954SCole Faust const int num_units = input_to_output_weights->info()->dimension(1);
273*c217d954SCole Faust const int output_size = output_state_out->info()->dimension(_out_state_output_size_dimension_idx);
274*c217d954SCole Faust
275*c217d954SCole Faust const UniformQuantizationInfo qinput = input->info()->quantization_info().uniform();
276*c217d954SCole Faust const UniformQuantizationInfo qcell_state_in = cell_state_in->info()->quantization_info().uniform();
277*c217d954SCole Faust const UniformQuantizationInfo qoutput_state_in = output_state_in->info()->quantization_info().uniform();
278*c217d954SCole Faust
279*c217d954SCole Faust _projection_bias = lstm_params.projection_bias();
280*c217d954SCole Faust _input_to_forget_weights = (input_to_forget_weights->info()->data_type() == DataType::QASYMM8_SIGNED) ? &_input_to_forget_weights_symm8 : input_to_forget_weights;
281*c217d954SCole Faust _input_to_cell_weights = input_to_cell_weights;
282*c217d954SCole Faust _input_to_output_weights = input_to_output_weights;
283*c217d954SCole Faust _recurrent_to_forget_weights = recurrent_to_forget_weights;
284*c217d954SCole Faust _recurrent_to_cell_weights = recurrent_to_cell_weights;
285*c217d954SCole Faust _recurrent_to_output_weights = recurrent_to_output_weights;
286*c217d954SCole Faust _projection_weights = lstm_params.projection_weights();
287*c217d954SCole Faust
288*c217d954SCole Faust // Layer normalization
289*c217d954SCole Faust _has_layer_norm = lstm_params.use_layer_norm();
290*c217d954SCole Faust if(_has_layer_norm)
291*c217d954SCole Faust {
292*c217d954SCole Faust set_layer_norm_weight(lstm_params.forget_layer_norm_weights(), LayerNormGate::Forget);
293*c217d954SCole Faust set_layer_norm_weight(lstm_params.cell_layer_norm_weights(), LayerNormGate::Cell);
294*c217d954SCole Faust set_layer_norm_weight(lstm_params.input_layer_norm_weights(), LayerNormGate::Input);
295*c217d954SCole Faust set_layer_norm_weight(lstm_params.output_layer_norm_weights(), LayerNormGate::Output);
296*c217d954SCole Faust
297*c217d954SCole Faust set_layer_norm_bias(forget_gate_bias, LayerNormGate::Forget);
298*c217d954SCole Faust set_layer_norm_bias(cell_bias, LayerNormGate::Cell);
299*c217d954SCole Faust set_layer_norm_bias(lstm_params.input_gate_bias(), LayerNormGate::Input);
300*c217d954SCole Faust set_layer_norm_bias(output_gate_bias, LayerNormGate::Output);
301*c217d954SCole Faust }
302*c217d954SCole Faust
303*c217d954SCole Faust _has_cifg = lstm_params.has_cifg_opt();
304*c217d954SCole Faust _has_projection = lstm_params.has_projection();
305*c217d954SCole Faust _has_peephole = lstm_params.has_peephole_opt();
306*c217d954SCole Faust
307*c217d954SCole Faust // Calculate and decompose effective scales for optimizing matmul calculation
308*c217d954SCole Faust const int32_t cell_shift = log2(qcell_state_in.scale);
309*c217d954SCole Faust
310*c217d954SCole Faust // Calculate quantized parameters for clipping.
311*c217d954SCole Faust int16_t quantized_cell_clip = 0;
312*c217d954SCole Faust if(lstm_params.cell_clip() > 0.0f)
313*c217d954SCole Faust {
314*c217d954SCole Faust quantized_cell_clip = quantize_qsymm16(lstm_params.cell_clip(), qcell_state_in);
315*c217d954SCole Faust }
316*c217d954SCole Faust _has_cell_clipping = quantized_cell_clip > 0;
317*c217d954SCole Faust
318*c217d954SCole Faust // Precompute effective bias for optimizing the matmul computations.
319*c217d954SCole Faust if(!_has_cifg)
320*c217d954SCole Faust {
321*c217d954SCole Faust _input_to_input_weights = lstm_params.input_to_input_weights();
322*c217d954SCole Faust _recurrent_to_input_weights = lstm_params.recurrent_to_input_weights();
323*c217d954SCole Faust
324*c217d954SCole Faust _input_to_input_reduction = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
325*c217d954SCole Faust _recurrent_to_input_reduction = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
326*c217d954SCole Faust _input_to_input_reduction->configure(_input_to_input_weights->info(), _input_to_input_eff_bias.info(), GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true));
327*c217d954SCole Faust _recurrent_to_input_reduction->configure(_recurrent_to_input_weights->info(), _recurrent_to_input_eff_bias.info(), GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true));
328*c217d954SCole Faust }
329*c217d954SCole Faust
330*c217d954SCole Faust _input_to_forget_reduction = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
331*c217d954SCole Faust _recurrent_to_forget_reduction = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
332*c217d954SCole Faust _input_to_cell_reduction = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
333*c217d954SCole Faust _recurrent_to_cell_reduction = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
334*c217d954SCole Faust _input_to_output_reduction = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
335*c217d954SCole Faust _recurrent_to_output_reduction = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
336*c217d954SCole Faust
337*c217d954SCole Faust _input_to_forget_reduction->configure(input_to_forget_weights->info(), _input_to_forget_eff_bias.info(), GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true));
338*c217d954SCole Faust _recurrent_to_forget_reduction->configure(recurrent_to_forget_weights->info(), _recurrent_to_forget_eff_bias.info(), GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true));
339*c217d954SCole Faust _input_to_cell_reduction->configure(input_to_cell_weights->info(), _input_to_cell_eff_bias.info(), GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true));
340*c217d954SCole Faust _recurrent_to_cell_reduction->configure(recurrent_to_cell_weights->info(), _recurrent_to_cell_eff_bias.info(), GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true));
341*c217d954SCole Faust _input_to_output_reduction->configure(input_to_output_weights->info(), _input_to_output_eff_bias.info(), GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true));
342*c217d954SCole Faust _recurrent_to_output_reduction->configure(recurrent_to_output_weights->info(), _recurrent_to_output_eff_bias.info(), GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset, true));
343*c217d954SCole Faust if(_has_projection)
344*c217d954SCole Faust {
345*c217d954SCole Faust _projection_reduction = std::make_unique<cpu::kernels::CpuGemmLowpMatrixAReductionKernel>();
346*c217d954SCole Faust _projection_reduction->configure(_projection_weights->info(), _projection_eff_bias.info(), GEMMLowpReductionKernelInfo(output_size, false, lstm_params.hidden_state_zero(), true));
347*c217d954SCole Faust if(_projection_bias != nullptr)
348*c217d954SCole Faust {
349*c217d954SCole Faust _projection_bias_add.configure(_projection_bias, &_projection_eff_bias, &_projection_eff_bias, ConvertPolicy::SATURATE);
350*c217d954SCole Faust }
351*c217d954SCole Faust }
352*c217d954SCole Faust
353*c217d954SCole Faust // Pre-transpose weights to be used in GEMM.
354*c217d954SCole Faust _transpose_input_to_forget_weights.configure(input_to_forget_weights, &_input_to_forget_weights_transposed);
355*c217d954SCole Faust _transpose_input_to_cell_weights.configure(input_to_cell_weights, &_input_to_cell_weights_transposed);
356*c217d954SCole Faust _transpose_input_to_output_weights.configure(input_to_output_weights, &_input_to_output_weights_transposed);
357*c217d954SCole Faust _transpose_recurrent_to_forget_weights.configure(recurrent_to_forget_weights, &_recurrent_to_forget_weights_transposed);
358*c217d954SCole Faust _transpose_recurrent_to_cell_weights.configure(recurrent_to_cell_weights, &_recurrent_to_cell_weights_transposed);
359*c217d954SCole Faust _transpose_recurrent_to_output_weights.configure(recurrent_to_output_weights, &_recurrent_to_output_weights_transposed);
360*c217d954SCole Faust if(!_has_cifg)
361*c217d954SCole Faust {
362*c217d954SCole Faust _transpose_input_to_input_weights.configure(lstm_params.input_to_input_weights(), &_input_to_input_weights_transposed);
363*c217d954SCole Faust _transpose_recurrent_to_input_weights.configure(lstm_params.recurrent_to_input_weights(), &_recurrent_to_input_weights_transposed);
364*c217d954SCole Faust }
365*c217d954SCole Faust if(_has_projection)
366*c217d954SCole Faust {
367*c217d954SCole Faust _transpose_projection_weights.configure(_projection_weights, &_projection_weights_transposed);
368*c217d954SCole Faust }
369*c217d954SCole Faust
370*c217d954SCole Faust GEMMLowpOutputStageInfo gemmlowp_info;
371*c217d954SCole Faust gemmlowp_info.type = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
372*c217d954SCole Faust gemmlowp_info.gemmlowp_min_bound = std::numeric_limits<int16_t>::lowest();
373*c217d954SCole Faust gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int16_t>::max();
374*c217d954SCole Faust gemmlowp_info.output_data_type = DataType::QSYMM16;
375*c217d954SCole Faust
376*c217d954SCole Faust const TensorInfo mm_out_info(TensorShape(num_units, batch_size), 1, DataType::S32);
377*c217d954SCole Faust // Forget gate.
378*c217d954SCole Faust const TensorInfo forget_gate_outstage_info(mm_out_info.tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.forget_intermediate_scale(), 0));
379*c217d954SCole Faust const float input_to_forget_scale = input_to_forget_weights->info()->quantization_info().uniform().scale * qinput.scale / lstm_params.forget_intermediate_scale();
380*c217d954SCole Faust configure_mm(_mm_input_to_forget, _input_to_forget_outstage, gemmlowp_info,
381*c217d954SCole Faust input, &_input_to_forget_weights_transposed, &_input_to_forget_eff_bias,
382*c217d954SCole Faust &_mm_input_to_forget_res, &_input_to_forget_outstage_res, input_to_forget_scale,
383*c217d954SCole Faust mm_out_info, forget_gate_outstage_info);
384*c217d954SCole Faust
385*c217d954SCole Faust const float recurrent_to_forget_scale = recurrent_to_forget_weights->info()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.forget_intermediate_scale();
386*c217d954SCole Faust configure_mm(_mm_recurrent_to_forget, _recurrent_to_forget_outstage, gemmlowp_info,
387*c217d954SCole Faust output_state_in, &_recurrent_to_forget_weights_transposed, &_recurrent_to_forget_eff_bias,
388*c217d954SCole Faust &_mm_recurrent_to_forget_res, &_recurrent_to_forget_outstage_res, recurrent_to_forget_scale,
389*c217d954SCole Faust mm_out_info, forget_gate_outstage_info);
390*c217d954SCole Faust
391*c217d954SCole Faust _accumulate_input_recurrent_forget.configure(&_input_to_forget_outstage_res, &_recurrent_to_forget_outstage_res, &_recurrent_to_forget_outstage_res, ConvertPolicy::SATURATE);
392*c217d954SCole Faust _input_to_forget_outstage_res.allocator()->allocate();
393*c217d954SCole Faust
394*c217d954SCole Faust if(_has_peephole)
395*c217d954SCole Faust {
396*c217d954SCole Faust _mul_cell_to_forget_res.allocator()->init(TensorInfo(cell_state_in->info()->tensor_shape(), 1, DataType::S32));
397*c217d954SCole Faust _memory_group.manage(&_mul_cell_to_forget_res);
398*c217d954SCole Faust _pixelwise_mul_cell_to_forget.configure(cell_state_in, lstm_params.cell_to_forget_weights(), &_mul_cell_to_forget_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
399*c217d954SCole Faust _cell_to_forget_outstage_res.allocator()->init(TensorInfo(_mul_cell_to_forget_res.info()->tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.forget_intermediate_scale(), 0)));
400*c217d954SCole Faust _memory_group.manage(&_cell_to_forget_outstage_res);
401*c217d954SCole Faust const float cell_to_forget_scale = std::pow(2, cell_shift) * lstm_params.cell_to_forget_weights()->info()->quantization_info().uniform().scale / lstm_params.forget_intermediate_scale();
402*c217d954SCole Faust quantization::calculate_quantized_multiplier(cell_to_forget_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift);
403*c217d954SCole Faust _cell_to_forget_outstage.configure(&_mul_cell_to_forget_res, nullptr, &_cell_to_forget_outstage_res, gemmlowp_info);
404*c217d954SCole Faust _mul_cell_to_forget_res.allocator()->allocate();
405*c217d954SCole Faust _accumulate_cell_forget.configure(&_recurrent_to_forget_outstage_res, &_cell_to_forget_outstage_res, &_recurrent_to_forget_outstage_res, ConvertPolicy::SATURATE);
406*c217d954SCole Faust _cell_to_forget_outstage_res.allocator()->allocate();
407*c217d954SCole Faust }
408*c217d954SCole Faust
409*c217d954SCole Faust Tensor *forget_activation_input = &_recurrent_to_forget_outstage_res;
410*c217d954SCole Faust
411*c217d954SCole Faust if(_has_layer_norm)
412*c217d954SCole Faust {
413*c217d954SCole Faust configure_layer_norm(LayerNormGate::Forget, forget_activation_input);
414*c217d954SCole Faust forget_activation_input->allocator()->allocate();
415*c217d954SCole Faust forget_activation_input = &get_layer_norm_output(LayerNormGate::Forget);
416*c217d954SCole Faust }
417*c217d954SCole Faust
418*c217d954SCole Faust // Output quantization info of Sigmoid and Tanh activations
419*c217d954SCole Faust const QuantizationInfo sigmoid_tanh_outqinfo(1.f / 32768.f, 0);
420*c217d954SCole Faust const TensorInfo forget_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
421*c217d954SCole Faust
422*c217d954SCole Faust _memory_group.manage(&_forget_gate);
423*c217d954SCole Faust _forget_gate.allocator()->init(forget_gate_info);
424*c217d954SCole Faust _forget_gate_sigmoid.configure(forget_activation_input, &_forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
425*c217d954SCole Faust forget_activation_input->allocator()->allocate();
426*c217d954SCole Faust
427*c217d954SCole Faust // Modulation gate.
428*c217d954SCole Faust const TensorInfo cell_outstage_info(mm_out_info.tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.cell_intermediate_scale(), 0));
429*c217d954SCole Faust const float input_to_cell_scale = input_to_cell_weights->info()->quantization_info().uniform().scale * qinput.scale / lstm_params.cell_intermediate_scale();
430*c217d954SCole Faust configure_mm(_mm_input_to_cell, _input_to_cell_outstage, gemmlowp_info,
431*c217d954SCole Faust input, &_input_to_cell_weights_transposed, &_input_to_cell_eff_bias,
432*c217d954SCole Faust &_mm_input_to_cell_res, &_input_to_cell_outstage_res, input_to_cell_scale,
433*c217d954SCole Faust mm_out_info, cell_outstage_info);
434*c217d954SCole Faust
435*c217d954SCole Faust const float recurrent_to_cell_scale = recurrent_to_cell_weights->info()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.cell_intermediate_scale();
436*c217d954SCole Faust configure_mm(_mm_recurrent_to_cell, _recurrent_to_cell_outstage, gemmlowp_info,
437*c217d954SCole Faust output_state_in, &_recurrent_to_cell_weights_transposed, &_recurrent_to_cell_eff_bias,
438*c217d954SCole Faust &_mm_recurrent_to_cell_res, &_recurrent_to_cell_outstage_res, recurrent_to_cell_scale,
439*c217d954SCole Faust mm_out_info, cell_outstage_info);
440*c217d954SCole Faust
441*c217d954SCole Faust _accumulate_input_recurrent_modulation.configure(&_input_to_cell_outstage_res, &_recurrent_to_cell_outstage_res, &_recurrent_to_cell_outstage_res, ConvertPolicy::SATURATE);
442*c217d954SCole Faust _input_to_cell_outstage_res.allocator()->allocate();
443*c217d954SCole Faust
444*c217d954SCole Faust Tensor *cell_activation_input = &_recurrent_to_cell_outstage_res;
445*c217d954SCole Faust
446*c217d954SCole Faust if(_has_layer_norm)
447*c217d954SCole Faust {
448*c217d954SCole Faust configure_layer_norm(LayerNormGate::Cell, cell_activation_input);
449*c217d954SCole Faust cell_activation_input->allocator()->allocate();
450*c217d954SCole Faust cell_activation_input = &get_layer_norm_output(LayerNormGate::Cell);
451*c217d954SCole Faust }
452*c217d954SCole Faust
453*c217d954SCole Faust const TensorInfo cell_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
454*c217d954SCole Faust
455*c217d954SCole Faust _memory_group.manage(&_cell_gate);
456*c217d954SCole Faust _cell_gate.allocator()->init(cell_gate_info);
457*c217d954SCole Faust _cell_gate_tanh.configure(cell_activation_input, &_cell_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
458*c217d954SCole Faust cell_activation_input->allocator()->allocate();
459*c217d954SCole Faust
460*c217d954SCole Faust // Input gate.
461*c217d954SCole Faust const TensorInfo input_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
462*c217d954SCole Faust _input_gate.allocator()->init(input_gate_info);
463*c217d954SCole Faust _memory_group.manage(&_input_gate);
464*c217d954SCole Faust if(_has_cifg)
465*c217d954SCole Faust {
466*c217d954SCole Faust _ones.allocator()->init(*_forget_gate.info());
467*c217d954SCole Faust _input_gate_sub.configure(&_ones, &_forget_gate, &_input_gate, ConvertPolicy::SATURATE);
468*c217d954SCole Faust _ones.allocator()->allocate();
469*c217d954SCole Faust }
470*c217d954SCole Faust else
471*c217d954SCole Faust {
472*c217d954SCole Faust const TensorInfo input_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.input_intermediate_scale(), 0));
473*c217d954SCole Faust const float input_to_input_scale = _input_to_input_weights->info()->quantization_info().uniform().scale * qinput.scale / lstm_params.input_intermediate_scale();
474*c217d954SCole Faust configure_mm(_mm_input_to_input, _input_to_input_outstage, gemmlowp_info,
475*c217d954SCole Faust input, &_input_to_input_weights_transposed, &_input_to_input_eff_bias,
476*c217d954SCole Faust &_mm_input_to_input_res, &_input_to_input_outstage_res, input_to_input_scale,
477*c217d954SCole Faust mm_out_info, input_outstage_info);
478*c217d954SCole Faust
479*c217d954SCole Faust const float recurrent_to_input_scale = _recurrent_to_input_weights->info()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.input_intermediate_scale();
480*c217d954SCole Faust configure_mm(_mm_recurrent_to_input, _recurrent_to_input_outstage, gemmlowp_info,
481*c217d954SCole Faust output_state_in, &_recurrent_to_input_weights_transposed, &_recurrent_to_input_eff_bias,
482*c217d954SCole Faust &_mm_recurrent_to_input_res, &_recurrent_to_input_outstage_res, recurrent_to_input_scale,
483*c217d954SCole Faust mm_out_info, input_outstage_info);
484*c217d954SCole Faust _accumulate_input_recurrent_input.configure(&_input_to_input_outstage_res, &_recurrent_to_input_outstage_res, &_recurrent_to_input_outstage_res, ConvertPolicy::SATURATE);
485*c217d954SCole Faust _input_to_input_outstage_res.allocator()->allocate();
486*c217d954SCole Faust
487*c217d954SCole Faust if(_has_peephole)
488*c217d954SCole Faust {
489*c217d954SCole Faust _mul_cell_to_input_res.allocator()->init(TensorInfo(cell_state_in->info()->tensor_shape(), 1, DataType::S32));
490*c217d954SCole Faust _memory_group.manage(&_mul_cell_to_input_res);
491*c217d954SCole Faust _pixelwise_mul_cell_to_input.configure(cell_state_in, lstm_params.cell_to_input_weights(), &_mul_cell_to_input_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
492*c217d954SCole Faust const float cell_to_input_scale = std::pow(2, cell_shift) * lstm_params.cell_to_input_weights()->info()->quantization_info().uniform().scale / lstm_params.input_intermediate_scale();
493*c217d954SCole Faust quantization::calculate_quantized_multiplier(cell_to_input_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift);
494*c217d954SCole Faust _cell_to_input_outstage_res.allocator()->init(TensorInfo(_mul_cell_to_input_res.info()->tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.input_intermediate_scale(), 0)));
495*c217d954SCole Faust _memory_group.manage(&_cell_to_input_outstage_res);
496*c217d954SCole Faust _cell_to_input_outstage.configure(&_mul_cell_to_input_res, nullptr, &_cell_to_input_outstage_res, gemmlowp_info);
497*c217d954SCole Faust _mul_cell_to_input_res.allocator()->allocate();
498*c217d954SCole Faust _accumulate_cell_input.configure(&_recurrent_to_input_outstage_res, &_cell_to_input_outstage_res, &_recurrent_to_input_outstage_res, ConvertPolicy::SATURATE);
499*c217d954SCole Faust _cell_to_input_outstage_res.allocator()->allocate();
500*c217d954SCole Faust }
501*c217d954SCole Faust
502*c217d954SCole Faust Tensor *input_activation_input = &_recurrent_to_input_outstage_res;
503*c217d954SCole Faust
504*c217d954SCole Faust if(_has_layer_norm)
505*c217d954SCole Faust {
506*c217d954SCole Faust configure_layer_norm(LayerNormGate::Input, input_activation_input);
507*c217d954SCole Faust input_activation_input->allocator()->allocate();
508*c217d954SCole Faust input_activation_input = &get_layer_norm_output(LayerNormGate::Input);
509*c217d954SCole Faust }
510*c217d954SCole Faust
511*c217d954SCole Faust _input_gate_sigmoid.configure(input_activation_input, &_input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
512*c217d954SCole Faust input_activation_input->allocator()->allocate();
513*c217d954SCole Faust }
514*c217d954SCole Faust // Cell.
515*c217d954SCole Faust // TODO(COMPMID-3395): Perform multiplication in the quantized domain in NEPixelWiseMultiplication
516*c217d954SCole Faust _pixelwise_mul_forget_cell.configure(&_forget_gate, cell_state_in, &_forget_gate, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
517*c217d954SCole Faust const float cell_gate_scale = _cell_gate.info()->quantization_info().uniform().scale;
518*c217d954SCole Faust const float mul_input_cell_scale = cell_gate_scale * std::pow(2, 15 + cell_shift);
519*c217d954SCole Faust const TensorInfo mul_input_cell_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(mul_input_cell_scale, 0));
520*c217d954SCole Faust _memory_group.manage(&_mul_input_cell_res);
521*c217d954SCole Faust _mul_input_cell_res.allocator()->init(mul_input_cell_info);
522*c217d954SCole Faust _pixelwise_mul_input_cell.configure(&_input_gate, &_cell_gate, &_mul_input_cell_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
523*c217d954SCole Faust _cell_gate.allocator()->allocate();
524*c217d954SCole Faust _add_forget_cell.configure(&_forget_gate, &_mul_input_cell_res, cell_state_out, ConvertPolicy::SATURATE);
525*c217d954SCole Faust _mul_input_cell_res.allocator()->allocate();
526*c217d954SCole Faust _forget_gate.allocator()->allocate();
527*c217d954SCole Faust if(_has_cell_clipping)
528*c217d954SCole Faust {
529*c217d954SCole Faust _cell_clip.configure(cell_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -quantized_cell_clip, quantized_cell_clip));
530*c217d954SCole Faust }
531*c217d954SCole Faust // Output gate.
532*c217d954SCole Faust const TensorInfo output_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.output_intermediate_scale(), 0));
533*c217d954SCole Faust const float input_to_output_scale = input_to_output_weights->info()->quantization_info().uniform().scale * qinput.scale / lstm_params.output_intermediate_scale();
534*c217d954SCole Faust configure_mm(_mm_input_to_output, _input_to_output_outstage, gemmlowp_info,
535*c217d954SCole Faust input, &_input_to_output_weights_transposed, &_input_to_output_eff_bias,
536*c217d954SCole Faust &_mm_input_to_output_res, &_input_to_output_outstage_res, input_to_output_scale,
537*c217d954SCole Faust mm_out_info, output_outstage_info);
538*c217d954SCole Faust
539*c217d954SCole Faust const float recurrent_to_output_scale = recurrent_to_output_weights->info()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.output_intermediate_scale();
540*c217d954SCole Faust configure_mm(_mm_recurrent_to_output, _recurrent_to_output_outstage, gemmlowp_info,
541*c217d954SCole Faust output_state_in, &_recurrent_to_output_weights_transposed, &_recurrent_to_output_eff_bias,
542*c217d954SCole Faust &_mm_recurrent_to_output_res, &_recurrent_to_output_outstage_res, recurrent_to_output_scale,
543*c217d954SCole Faust mm_out_info, output_outstage_info);
544*c217d954SCole Faust
545*c217d954SCole Faust _accumulate_input_recurrent_output.configure(&_recurrent_to_output_outstage_res, &_input_to_output_outstage_res, &_recurrent_to_output_outstage_res, ConvertPolicy::SATURATE);
546*c217d954SCole Faust _input_to_output_outstage_res.allocator()->allocate();
547*c217d954SCole Faust
548*c217d954SCole Faust if(_has_peephole)
549*c217d954SCole Faust {
550*c217d954SCole Faust // TODO(COMPMID-3395): Perform multiplication in the quantized domain in NEPixelWiseMultiplication
551*c217d954SCole Faust // Here we are not using the output stage because all operations are done in float
552*c217d954SCole Faust _mul_cell_to_output_res.allocator()->init(TensorInfo(cell_state_out->info()->tensor_shape(), 1, DataType::S32));
553*c217d954SCole Faust _memory_group.manage(&_mul_cell_to_output_res);
554*c217d954SCole Faust _pixelwise_mul_cell_to_output.configure(cell_state_out, lstm_params.cell_to_output_weights(), &_mul_cell_to_output_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
555*c217d954SCole Faust
556*c217d954SCole Faust const float cell_to_output_scale = std::pow(2, cell_shift) * lstm_params.cell_to_output_weights()->info()->quantization_info().uniform().scale / lstm_params.output_intermediate_scale();
557*c217d954SCole Faust quantization::calculate_quantized_multiplier(cell_to_output_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift);
558*c217d954SCole Faust _cell_to_output_outstage_res.allocator()->init(TensorInfo(_mul_cell_to_output_res.info()->tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.output_intermediate_scale(), 0)));
559*c217d954SCole Faust _memory_group.manage(&_cell_to_output_outstage_res);
560*c217d954SCole Faust _cell_to_output_outstage.configure(&_mul_cell_to_output_res, nullptr, &_cell_to_output_outstage_res, gemmlowp_info);
561*c217d954SCole Faust _mul_cell_to_output_res.allocator()->allocate();
562*c217d954SCole Faust
563*c217d954SCole Faust _accumulate_cell_to_output.configure(&_recurrent_to_output_outstage_res, &_cell_to_output_outstage_res, &_recurrent_to_output_outstage_res, ConvertPolicy::SATURATE);
564*c217d954SCole Faust _cell_to_output_outstage_res.allocator()->allocate();
565*c217d954SCole Faust }
566*c217d954SCole Faust
567*c217d954SCole Faust Tensor *output_activation_input = &_recurrent_to_output_outstage_res;
568*c217d954SCole Faust
569*c217d954SCole Faust if(_has_layer_norm)
570*c217d954SCole Faust {
571*c217d954SCole Faust configure_layer_norm(LayerNormGate::Output, output_activation_input);
572*c217d954SCole Faust output_activation_input->allocator()->allocate();
573*c217d954SCole Faust output_activation_input = &get_layer_norm_output(LayerNormGate::Output);
574*c217d954SCole Faust }
575*c217d954SCole Faust const TensorInfo output_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
576*c217d954SCole Faust
577*c217d954SCole Faust _memory_group.manage(&_output_gate);
578*c217d954SCole Faust _output_gate.allocator()->init(output_gate_info);
579*c217d954SCole Faust _output_gate_sigmoid.configure(output_activation_input, &_output_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
580*c217d954SCole Faust output_activation_input->allocator()->allocate();
581*c217d954SCole Faust
582*c217d954SCole Faust // Hidden.
583*c217d954SCole Faust _hidden_tanh.configure(cell_state_out, &_input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
584*c217d954SCole Faust // TODO(COMPMID-3395): Perform multiplication in the quantized domain in NEPixelWiseMultiplication
585*c217d954SCole Faust _memory_group.manage(&_hidden_mul_res);
586*c217d954SCole Faust const TensorInfo hidden_mul_res(_input_gate.info()->tensor_shape(), 1, DataType::S32);
587*c217d954SCole Faust _hidden_mul_res.allocator()->init(hidden_mul_res);
588*c217d954SCole Faust _pixelwise_mul_hidden.configure(&_output_gate, &_input_gate, &_hidden_mul_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
589*c217d954SCole Faust _output_gate.allocator()->allocate();
590*c217d954SCole Faust _input_gate.allocator()->allocate();
591*c217d954SCole Faust const float hidden_state_scale = std::pow(2, -15) / lstm_params.hidden_state_scale() * std::pow(2, -15);
592*c217d954SCole Faust quantization::calculate_quantized_multiplier(hidden_state_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift, /* ignore_epsilon */ true);
593*c217d954SCole Faust gemmlowp_info.gemmlowp_offset = lstm_params.hidden_state_zero();
594*c217d954SCole Faust gemmlowp_info.output_data_type = output_state_in->info()->data_type();
595*c217d954SCole Faust
596*c217d954SCole Faust _projection_tensor_copy_required = (num_units != output_size);
597*c217d954SCole Faust ITensor *hidden_gate_result = output_state_out;
598*c217d954SCole Faust
599*c217d954SCole Faust _memory_group.manage(&_hidden_gate);
600*c217d954SCole Faust
601*c217d954SCole Faust if(_projection_tensor_copy_required)
602*c217d954SCole Faust {
603*c217d954SCole Faust _hidden_gate.allocator()->init(*output_state_out->info());
604*c217d954SCole Faust _hidden_gate.info()->set_tensor_shape(_hidden_mul_res.info()->tensor_shape());
605*c217d954SCole Faust hidden_gate_result = &_hidden_gate;
606*c217d954SCole Faust }
607*c217d954SCole Faust
608*c217d954SCole Faust _hidden_outstage.configure(&_hidden_mul_res, nullptr, hidden_gate_result, gemmlowp_info);
609*c217d954SCole Faust _hidden_mul_res.allocator()->allocate();
610*c217d954SCole Faust
611*c217d954SCole Faust // Projection.
612*c217d954SCole Faust if(_has_projection)
613*c217d954SCole Faust {
614*c217d954SCole Faust const TensorInfo projection_outstage_info(*output_state_out->info());
615*c217d954SCole Faust const UniformQuantizationInfo qprojection = _projection_weights->info()->quantization_info().uniform();
616*c217d954SCole Faust const float projection_scale = qprojection.scale * lstm_params.hidden_state_scale() / qoutput_state_in.scale;
617*c217d954SCole Faust gemmlowp_info.gemmlowp_offset = qoutput_state_in.offset;
618*c217d954SCole Faust gemmlowp_info.gemmlowp_min_bound = std::numeric_limits<int8_t>::lowest();
619*c217d954SCole Faust gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int8_t>::max();
620*c217d954SCole Faust gemmlowp_info.output_data_type = DataType::QASYMM8_SIGNED;
621*c217d954SCole Faust
622*c217d954SCole Faust TensorInfo projection_mm_out_info{ mm_out_info };
623*c217d954SCole Faust projection_mm_out_info.set_tensor_shape(TensorShape(output_size, batch_size));
624*c217d954SCole Faust
625*c217d954SCole Faust configure_mm(_mm_projection, _projection_outstage, gemmlowp_info,
626*c217d954SCole Faust hidden_gate_result, &_projection_weights_transposed, &_projection_eff_bias,
627*c217d954SCole Faust &_mm_projection_res, &_projection_outstage_res, projection_scale,
628*c217d954SCole Faust projection_mm_out_info, projection_outstage_info);
629*c217d954SCole Faust
630*c217d954SCole Faust ITensor *accumulate_destination = output_state_out;
631*c217d954SCole Faust
632*c217d954SCole Faust if(_projection_tensor_copy_required)
633*c217d954SCole Faust {
634*c217d954SCole Faust _hidden_gate.allocator()->allocate();
635*c217d954SCole Faust _projection_accumulate_res.allocator()->init(*output_state_in->info());
636*c217d954SCole Faust _projection_accumulate_res.info()->set_tensor_shape(_projection_outstage_res.info()->tensor_shape());
637*c217d954SCole Faust _projection_output_to_accumulate_copy.configure(*output_state_in, _projection_accumulate_res);
638*c217d954SCole Faust accumulate_destination = &_projection_accumulate_res;
639*c217d954SCole Faust }
640*c217d954SCole Faust
641*c217d954SCole Faust _accumulate_projection.configure(&_projection_outstage_res, accumulate_destination, accumulate_destination, ConvertPolicy::SATURATE);
642*c217d954SCole Faust _projection_outstage_res.allocator()->allocate();
643*c217d954SCole Faust
644*c217d954SCole Faust if(_projection_tensor_copy_required)
645*c217d954SCole Faust {
646*c217d954SCole Faust _projection_accumulate_to_output_copy.configure(_projection_accumulate_res, *output_state_out);
647*c217d954SCole Faust _projection_accumulate_res.allocator()->allocate();
648*c217d954SCole Faust }
649*c217d954SCole Faust
650*c217d954SCole Faust int8_t quantized_projection_clip{ 0 };
651*c217d954SCole Faust if(lstm_params.projection_clip() > 0.0f)
652*c217d954SCole Faust {
653*c217d954SCole Faust quantized_projection_clip = utility::clamp<int8_t>(lstm_params.projection_clip() / qprojection.scale, -128, 127);
654*c217d954SCole Faust }
655*c217d954SCole Faust
656*c217d954SCole Faust if(quantized_projection_clip > 0)
657*c217d954SCole Faust {
658*c217d954SCole Faust _projection_clip.configure(output_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -quantized_projection_clip, quantized_projection_clip));
659*c217d954SCole Faust _has_projection_clipping = true;
660*c217d954SCole Faust }
661*c217d954SCole Faust }
662*c217d954SCole Faust else
663*c217d954SCole Faust {
664*c217d954SCole Faust if(_projection_tensor_copy_required)
665*c217d954SCole Faust {
666*c217d954SCole Faust _hidden_to_output_copy.configure(_hidden_gate, *output_state_out);
667*c217d954SCole Faust _hidden_gate.allocator()->allocate();
668*c217d954SCole Faust }
669*c217d954SCole Faust }
670*c217d954SCole Faust
671*c217d954SCole Faust // Copy output_state_out to output
672*c217d954SCole Faust _copy_output.configure(output_state_out, output);
673*c217d954SCole Faust }
674*c217d954SCole Faust
validate(const ITensorInfo * input,const ITensorInfo * input_to_forget_weights,const ITensorInfo * input_to_cell_weights,const ITensorInfo * input_to_output_weights,const ITensorInfo * recurrent_to_forget_weights,const ITensorInfo * recurrent_to_cell_weights,const ITensorInfo * recurrent_to_output_weights,const ITensorInfo * forget_gate_bias,const ITensorInfo * cell_bias,const ITensorInfo * output_gate_bias,const ITensorInfo * cell_state_in,const ITensorInfo * output_state_in,const ITensorInfo * cell_state_out,const ITensorInfo * output_state_out,const ITensorInfo * output,const LSTMParams<ITensorInfo> & lstm_params)675*c217d954SCole Faust Status NEQLSTMLayer::validate(const ITensorInfo *input,
676*c217d954SCole Faust const ITensorInfo *input_to_forget_weights, const ITensorInfo *input_to_cell_weights, const ITensorInfo *input_to_output_weights,
677*c217d954SCole Faust const ITensorInfo *recurrent_to_forget_weights, const ITensorInfo *recurrent_to_cell_weights, const ITensorInfo *recurrent_to_output_weights,
678*c217d954SCole Faust const ITensorInfo *forget_gate_bias, const ITensorInfo *cell_bias, const ITensorInfo *output_gate_bias,
679*c217d954SCole Faust const ITensorInfo *cell_state_in, const ITensorInfo *output_state_in,
680*c217d954SCole Faust const ITensorInfo *cell_state_out, const ITensorInfo *output_state_out, const ITensorInfo *output,
681*c217d954SCole Faust const LSTMParams<ITensorInfo> &lstm_params)
682*c217d954SCole Faust {
683*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights, recurrent_to_forget_weights, recurrent_to_cell_weights,
684*c217d954SCole Faust recurrent_to_output_weights, forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in,
685*c217d954SCole Faust cell_state_out, output_state_out, output);
686*c217d954SCole Faust
687*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8_SIGNED);
688*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->num_dimensions() != 2, "Input must have exactly 2 dimensions");
689*c217d954SCole Faust
690*c217d954SCole Faust const unsigned int input_size = input->dimension(0);
691*c217d954SCole Faust const unsigned int batch_size = input->dimension(1);
692*c217d954SCole Faust const unsigned int num_units = input_to_output_weights->dimension(1);
693*c217d954SCole Faust const unsigned int output_size = output_state_out->dimension(_out_state_output_size_dimension_idx);
694*c217d954SCole Faust
695*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON(input_to_output_weights->num_dimensions() != 2);
696*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON(input_to_output_weights->dimension(0) != input_size);
697*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input_to_output_weights, input_to_forget_weights, input_to_cell_weights);
698*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_output_weights->num_dimensions() != 2);
699*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_output_weights->dimension(1) != num_units);
700*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(recurrent_to_output_weights, recurrent_to_forget_weights, recurrent_to_cell_weights);
701*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input_to_forget_weights, 1, DataType::QASYMM8_SIGNED, DataType::QSYMM8);
702*c217d954SCole Faust
703*c217d954SCole Faust // If the input_to_forget_weights data type is DataType::QSYMM8 then it can never match the other weights as they are all DataType::QASYMM8_SIGNED
704*c217d954SCole Faust if (input_to_forget_weights->data_type() == DataType::QSYMM8)
705*c217d954SCole Faust {
706*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_to_cell_weights, input_to_output_weights,
707*c217d954SCole Faust recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights);
708*c217d954SCole Faust }
709*c217d954SCole Faust else
710*c217d954SCole Faust {
711*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
712*c217d954SCole Faust recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights);
713*c217d954SCole Faust }
714*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON(forget_gate_bias->num_dimensions() != 1);
715*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON(forget_gate_bias->dimension(0) != num_units);
716*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(forget_gate_bias, cell_bias, output_gate_bias);
717*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(forget_gate_bias, 1, DataType::S32);
718*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(forget_gate_bias, cell_bias, output_gate_bias);
719*c217d954SCole Faust
720*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON(cell_state_in->num_dimensions() != 2);
721*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON(cell_state_in->dimension(0) != num_units);
722*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON(cell_state_in->dimension(1) != batch_size);
723*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(cell_state_in, 1, DataType::QSYMM16);
724*c217d954SCole Faust
725*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON(output_state_in->num_dimensions() != 2);
726*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON(output_state_in->dimension(0) != output_size);
727*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON(output_state_in->dimension(1) != batch_size);
728*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output_state_in);
729*c217d954SCole Faust
730*c217d954SCole Faust // Check whether peephole weights are all there or none
731*c217d954SCole Faust if(lstm_params.has_peephole_opt())
732*c217d954SCole Faust {
733*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_forget_weights(), lstm_params.cell_to_output_weights());
734*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lstm_params.cell_to_forget_weights(), 1, DataType::QSYMM16);
735*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_forget_weights()->num_dimensions() != 1);
736*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_forget_weights()->dimension(0) != num_units);
737*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lstm_params.cell_to_forget_weights(), lstm_params.cell_to_output_weights());
738*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(lstm_params.cell_to_forget_weights(), lstm_params.cell_to_output_weights());
739*c217d954SCole Faust
740*c217d954SCole Faust if(!lstm_params.has_cifg_opt())
741*c217d954SCole Faust {
742*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_input_weights());
743*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lstm_params.cell_to_forget_weights(), lstm_params.cell_to_input_weights());
744*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(lstm_params.cell_to_forget_weights(), lstm_params.cell_to_input_weights());
745*c217d954SCole Faust }
746*c217d954SCole Faust }
747*c217d954SCole Faust
748*c217d954SCole Faust const UniformQuantizationInfo qinput = input->quantization_info().uniform();
749*c217d954SCole Faust const UniformQuantizationInfo qcell_state_in = cell_state_in->quantization_info().uniform();
750*c217d954SCole Faust const UniformQuantizationInfo qoutput_state_in = output_state_in->quantization_info().uniform();
751*c217d954SCole Faust
752*c217d954SCole Faust // Calculate and decompose effective scales for optimizing matmul calculation
753*c217d954SCole Faust const int32_t cell_shift = log2(qcell_state_in.scale);
754*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON(cell_shift > -9);
755*c217d954SCole Faust
756*c217d954SCole Faust // Calculate quantized parameters for clipping.
757*c217d954SCole Faust int16_t quantized_cell_clip = 0;
758*c217d954SCole Faust if(lstm_params.cell_clip() > 0.0f)
759*c217d954SCole Faust {
760*c217d954SCole Faust quantized_cell_clip = quantize_qsymm16(lstm_params.cell_clip(), qcell_state_in);
761*c217d954SCole Faust }
762*c217d954SCole Faust
763*c217d954SCole Faust // Precompute effective bias for optimizing the matmul computations.
764*c217d954SCole Faust const TensorInfo eff_bias_info(TensorShape(num_units), 1, DataType::S32);
765*c217d954SCole Faust const TensorInfo projection_eff_bias_info(TensorShape(output_size), 1, DataType::S32);
766*c217d954SCole Faust if(!lstm_params.has_cifg_opt())
767*c217d954SCole Faust {
768*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmLowpMatrixAReductionKernel::validate(lstm_params.input_to_input_weights(), &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false,
769*c217d954SCole Faust -qinput.offset, true)));
770*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmLowpMatrixAReductionKernel::validate(lstm_params.recurrent_to_input_weights(), &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false,
771*c217d954SCole Faust -qoutput_state_in.offset,
772*c217d954SCole Faust true)));
773*c217d954SCole Faust }
774*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmLowpMatrixAReductionKernel::validate(input_to_forget_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true)));
775*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmLowpMatrixAReductionKernel::validate(recurrent_to_forget_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false,
776*c217d954SCole Faust -qoutput_state_in.offset, true)));
777*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmLowpMatrixAReductionKernel::validate(input_to_cell_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true)));
778*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmLowpMatrixAReductionKernel::validate(recurrent_to_cell_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qoutput_state_in.offset,
779*c217d954SCole Faust true)));
780*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmLowpMatrixAReductionKernel::validate(input_to_output_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false, -qinput.offset, true)));
781*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmLowpMatrixAReductionKernel::validate(recurrent_to_output_weights, &eff_bias_info, GEMMLowpReductionKernelInfo(num_units, false,
782*c217d954SCole Faust -qoutput_state_in.offset, true)));
783*c217d954SCole Faust if(lstm_params.has_projection())
784*c217d954SCole Faust {
785*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmLowpMatrixAReductionKernel::validate(lstm_params.projection_weights(), &projection_eff_bias_info, GEMMLowpReductionKernelInfo(output_size, false,
786*c217d954SCole Faust lstm_params.hidden_state_zero(),
787*c217d954SCole Faust true)));
788*c217d954SCole Faust if(lstm_params.projection_bias() != nullptr)
789*c217d954SCole Faust {
790*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lstm_params.projection_bias(), 1, DataType::S32);
791*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(lstm_params.projection_bias(), &projection_eff_bias_info, &projection_eff_bias_info, ConvertPolicy::SATURATE));
792*c217d954SCole Faust }
793*c217d954SCole Faust }
794*c217d954SCole Faust
795*c217d954SCole Faust const TensorInfo input_weights_transposed(TensorShape(num_units, input_size), 1, input_to_cell_weights->data_type(), input_to_cell_weights->quantization_info());
796*c217d954SCole Faust const TensorInfo input_to_output_weights_transposed(TensorShape(num_units, input_size), 1, input_to_output_weights->data_type(), input_to_output_weights->quantization_info());
797*c217d954SCole Faust const TensorInfo recurrent_to_forget_weights_transposed(TensorShape(num_units, output_size), 1, recurrent_to_forget_weights->data_type(), recurrent_to_forget_weights->quantization_info());
798*c217d954SCole Faust const TensorInfo recurrent_to_cell_weights_transposed(TensorShape(num_units, output_size), 1, recurrent_to_cell_weights->data_type(), recurrent_to_cell_weights->quantization_info());
799*c217d954SCole Faust const TensorInfo recurrent_to_output_weights_transposed(TensorShape(num_units, output_size), 1, recurrent_to_output_weights->data_type(), recurrent_to_output_weights->quantization_info());
800*c217d954SCole Faust const TensorInfo recurrent_weights_transposed(TensorShape(num_units, output_size), 1, recurrent_to_forget_weights->data_type(), recurrent_to_forget_weights->quantization_info());
801*c217d954SCole Faust
802*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(input_to_cell_weights, &input_weights_transposed));
803*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(input_to_output_weights, &input_to_output_weights_transposed));
804*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(recurrent_to_forget_weights, &recurrent_to_forget_weights_transposed));
805*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(recurrent_to_cell_weights, &recurrent_to_cell_weights_transposed));
806*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(recurrent_to_output_weights, &recurrent_to_output_weights_transposed));
807*c217d954SCole Faust if(!lstm_params.has_cifg_opt())
808*c217d954SCole Faust {
809*c217d954SCole Faust const TensorInfo recurrent_to_input_weights_transposed(TensorShape(num_units, output_size), 1,
810*c217d954SCole Faust recurrent_to_forget_weights->data_type(), lstm_params.recurrent_to_input_weights()->quantization_info());
811*c217d954SCole Faust const TensorInfo input_to_input_weights_transposed(TensorShape(num_units, input_size), 1,
812*c217d954SCole Faust lstm_params.input_to_input_weights()->data_type(), lstm_params.input_to_input_weights()->quantization_info());
813*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(lstm_params.input_to_input_weights(), &input_to_input_weights_transposed));
814*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(lstm_params.recurrent_to_input_weights(), &recurrent_to_input_weights_transposed));
815*c217d954SCole Faust }
816*c217d954SCole Faust if(lstm_params.has_projection())
817*c217d954SCole Faust {
818*c217d954SCole Faust const TensorInfo projection_weights_transposed(TensorShape(output_size, num_units), 1, lstm_params.projection_weights()->data_type(), lstm_params.projection_weights()->quantization_info());
819*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(lstm_params.projection_weights(), &projection_weights_transposed));
820*c217d954SCole Faust }
821*c217d954SCole Faust
822*c217d954SCole Faust GEMMLowpOutputStageInfo gemmlowp_info;
823*c217d954SCole Faust gemmlowp_info.type = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
824*c217d954SCole Faust gemmlowp_info.gemmlowp_min_bound = std::numeric_limits<int16_t>::lowest();
825*c217d954SCole Faust gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int16_t>::max();
826*c217d954SCole Faust gemmlowp_info.output_data_type = DataType::QSYMM16;
827*c217d954SCole Faust
828*c217d954SCole Faust const bool has_layer_norm = lstm_params.use_layer_norm();
829*c217d954SCole Faust
830*c217d954SCole Faust // Forget gate.
831*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.forget_intermediate_scale() == 0);
832*c217d954SCole Faust const TensorInfo forget_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.forget_intermediate_scale(), 0));
833*c217d954SCole Faust const TensorInfo mm_out_info(TensorShape(num_units, batch_size), 1, DataType::S32);
834*c217d954SCole Faust const float input_to_forget_scale = input_to_forget_weights->quantization_info().uniform().scale * qinput.scale / lstm_params.forget_intermediate_scale();
835*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_forget_scale, &mm_out_info, &forget_outstage_info));
836*c217d954SCole Faust
837*c217d954SCole Faust const float recurrent_to_forget_scale = recurrent_to_forget_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.forget_intermediate_scale();
838*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_forget_scale, &mm_out_info, &forget_outstage_info));
839*c217d954SCole Faust
840*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&forget_outstage_info, &forget_outstage_info, &forget_outstage_info, ConvertPolicy::SATURATE));
841*c217d954SCole Faust
842*c217d954SCole Faust if(lstm_params.has_peephole_opt())
843*c217d954SCole Faust {
844*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lstm_params.cell_to_forget_weights(), 1, DataType::QSYMM16);
845*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(cell_state_in, lstm_params.cell_to_forget_weights(), &mm_out_info, 1.f, ConvertPolicy::SATURATE,
846*c217d954SCole Faust RoundingPolicy::TO_ZERO));
847*c217d954SCole Faust const float cell_to_forget_scale = std::pow(2, cell_shift) * lstm_params.cell_to_forget_weights()->quantization_info().uniform().scale / lstm_params.forget_intermediate_scale();
848*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(cell_to_forget_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
849*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpOutputStage::validate(&mm_out_info, nullptr, &forget_outstage_info, gemmlowp_info));
850*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&forget_outstage_info, &forget_outstage_info, &forget_outstage_info, ConvertPolicy::SATURATE));
851*c217d954SCole Faust }
852*c217d954SCole Faust
853*c217d954SCole Faust if(has_layer_norm)
854*c217d954SCole Faust {
855*c217d954SCole Faust const ITensorInfo *w_info = lstm_params.forget_layer_norm_weights();
856*c217d954SCole Faust const ITensorInfo *b_info = forget_gate_bias;
857*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(forget_outstage_info, *w_info, *b_info));
858*c217d954SCole Faust }
859*c217d954SCole Faust
860*c217d954SCole Faust // Output quantization info of Sigmoid and Tanh activations
861*c217d954SCole Faust const QuantizationInfo sigmoid_tanh_outqinfo(1.f / 32768.f, 0);
862*c217d954SCole Faust const TensorInfo forget_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
863*c217d954SCole Faust
864*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&forget_outstage_info, &forget_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
865*c217d954SCole Faust
866*c217d954SCole Faust // Modulation gate.
867*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_intermediate_scale() == 0);
868*c217d954SCole Faust const TensorInfo cell_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.cell_intermediate_scale(), 0));
869*c217d954SCole Faust const float input_to_cell_scale = input_to_cell_weights->quantization_info().uniform().scale * qinput.scale / lstm_params.cell_intermediate_scale();
870*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_cell_scale, &mm_out_info, &cell_outstage_info));
871*c217d954SCole Faust
872*c217d954SCole Faust const float recurrent_to_cell_scale = recurrent_to_cell_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.cell_intermediate_scale();
873*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_cell_scale, &mm_out_info, &cell_outstage_info));
874*c217d954SCole Faust
875*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&cell_outstage_info, &cell_outstage_info, &cell_outstage_info, ConvertPolicy::SATURATE));
876*c217d954SCole Faust
877*c217d954SCole Faust if(has_layer_norm)
878*c217d954SCole Faust {
879*c217d954SCole Faust const ITensorInfo *w_info = lstm_params.cell_layer_norm_weights();
880*c217d954SCole Faust const ITensorInfo *b_info = cell_bias;
881*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(cell_outstage_info, *w_info, *b_info));
882*c217d954SCole Faust }
883*c217d954SCole Faust const TensorInfo cell_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
884*c217d954SCole Faust
885*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&cell_outstage_info, &cell_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f)));
886*c217d954SCole Faust
887*c217d954SCole Faust // Input gate.
888*c217d954SCole Faust const TensorInfo input_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
889*c217d954SCole Faust if(lstm_params.has_cifg_opt())
890*c217d954SCole Faust {
891*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON_MSG(lstm_params.input_gate_bias() != nullptr, "Input gate bias must not be present when CIFG is used");
892*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticSubtraction::validate(&input_gate_info, &forget_gate_info, &forget_gate_info, ConvertPolicy::SATURATE));
893*c217d954SCole Faust }
894*c217d954SCole Faust else
895*c217d954SCole Faust {
896*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.input_to_input_weights(), lstm_params.recurrent_to_input_weights(), lstm_params.input_gate_bias());
897*c217d954SCole Faust
898*c217d954SCole Faust // If the input_to_forget_weights data type is DataType::QSYMM8 then it can never match the other weights as they are all DataType::QASYMM8_SIGNED
899*c217d954SCole Faust if (input_to_forget_weights->data_type() == DataType::QSYMM8)
900*c217d954SCole Faust {
901*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lstm_params.input_to_input_weights(), lstm_params.recurrent_to_input_weights());
902*c217d954SCole Faust }
903*c217d954SCole Faust else
904*c217d954SCole Faust {
905*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_to_forget_weights, lstm_params.input_to_input_weights(), lstm_params.recurrent_to_input_weights());
906*c217d954SCole Faust }
907*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input_to_forget_weights, lstm_params.input_to_input_weights());
908*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(recurrent_to_forget_weights, lstm_params.recurrent_to_input_weights());
909*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(forget_gate_bias, lstm_params.input_gate_bias());
910*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(forget_gate_bias, lstm_params.input_gate_bias());
911*c217d954SCole Faust
912*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_intermediate_scale() == 0);
913*c217d954SCole Faust const TensorInfo input_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.input_intermediate_scale(), 0));
914*c217d954SCole Faust const float input_to_input_scale = lstm_params.input_to_input_weights()->quantization_info().uniform().scale * qinput.scale / lstm_params.input_intermediate_scale();
915*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_input_scale, &mm_out_info, &input_outstage_info));
916*c217d954SCole Faust
917*c217d954SCole Faust const float recurrent_to_input_scale = lstm_params.recurrent_to_input_weights()->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.input_intermediate_scale();
918*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_input_scale, &mm_out_info, &input_outstage_info));
919*c217d954SCole Faust
920*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&input_outstage_info, &input_outstage_info, &input_outstage_info, ConvertPolicy::SATURATE));
921*c217d954SCole Faust
922*c217d954SCole Faust if(lstm_params.has_peephole_opt())
923*c217d954SCole Faust {
924*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(cell_state_in, lstm_params.cell_to_input_weights(), &mm_out_info, 1.f, ConvertPolicy::SATURATE,
925*c217d954SCole Faust RoundingPolicy::TO_ZERO));
926*c217d954SCole Faust const float cell_to_input_scale = std::pow(2, cell_shift) * lstm_params.cell_to_input_weights()->quantization_info().uniform().scale / lstm_params.input_intermediate_scale();
927*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(cell_to_input_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
928*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpOutputStage::validate(&mm_out_info, &eff_bias_info, &input_outstage_info, gemmlowp_info));
929*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&input_outstage_info, &input_outstage_info, &input_outstage_info, ConvertPolicy::SATURATE));
930*c217d954SCole Faust }
931*c217d954SCole Faust
932*c217d954SCole Faust if(has_layer_norm)
933*c217d954SCole Faust {
934*c217d954SCole Faust const ITensorInfo *w_info = lstm_params.input_layer_norm_weights();
935*c217d954SCole Faust const ITensorInfo *b_info = lstm_params.input_gate_bias();
936*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(input_outstage_info, *w_info, *b_info));
937*c217d954SCole Faust }
938*c217d954SCole Faust
939*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&input_outstage_info, &input_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f)));
940*c217d954SCole Faust }
941*c217d954SCole Faust // Cell.
942*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(&forget_gate_info, cell_state_in, &forget_gate_info, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
943*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(&input_gate_info, cell_state_in, &cell_gate_info, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
944*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&forget_gate_info, &cell_gate_info, cell_state_out, ConvertPolicy::SATURATE));
945*c217d954SCole Faust if(quantized_cell_clip > 0)
946*c217d954SCole Faust {
947*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(cell_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -quantized_cell_clip,
948*c217d954SCole Faust quantized_cell_clip)));
949*c217d954SCole Faust }
950*c217d954SCole Faust // Output gate.
951*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.output_intermediate_scale() == 0);
952*c217d954SCole Faust const TensorInfo output_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.output_intermediate_scale(), 0));
953*c217d954SCole Faust const float input_to_output_scale = input_to_output_weights->quantization_info().uniform().scale * qinput.scale / lstm_params.output_intermediate_scale();
954*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, input, &input_weights_transposed, &eff_bias_info, input_to_output_scale, &mm_out_info, &output_outstage_info));
955*c217d954SCole Faust
956*c217d954SCole Faust const float recurrent_to_output_scale = recurrent_to_output_weights->quantization_info().uniform().scale * qoutput_state_in.scale / lstm_params.output_intermediate_scale();
957*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, output_state_in, &recurrent_weights_transposed, &eff_bias_info, recurrent_to_output_scale, &mm_out_info, &output_outstage_info));
958*c217d954SCole Faust
959*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&output_outstage_info, &output_outstage_info, &output_outstage_info, ConvertPolicy::SATURATE));
960*c217d954SCole Faust if(lstm_params.has_peephole_opt())
961*c217d954SCole Faust {
962*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lstm_params.cell_to_output_weights(), 1, DataType::QSYMM16);
963*c217d954SCole Faust // TODO(COMPMID-3395): Perform multiplication in the quantized domain in NEPixelWiseMultiplication
964*c217d954SCole Faust // Here we are not using the output stage because all operations are done in float
965*c217d954SCole Faust // const float cell_to_output_scale = std::pow(2, cell_shift) * lstm_params.cell_to_output_weights()->quantization_info().uniform().scale / lstm_params.output_intermediate_scale();
966*c217d954SCole Faust // ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(cell_to_output_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
967*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(cell_state_out, lstm_params.cell_to_output_weights(), &output_outstage_info, 1.f, ConvertPolicy::SATURATE,
968*c217d954SCole Faust RoundingPolicy::TO_ZERO));
969*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&output_outstage_info, &output_outstage_info, &output_outstage_info, ConvertPolicy::SATURATE));
970*c217d954SCole Faust }
971*c217d954SCole Faust
972*c217d954SCole Faust if(has_layer_norm)
973*c217d954SCole Faust {
974*c217d954SCole Faust const ITensorInfo *w_info = lstm_params.output_layer_norm_weights();
975*c217d954SCole Faust const ITensorInfo *b_info = output_gate_bias;
976*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(output_outstage_info, *w_info, *b_info));
977*c217d954SCole Faust }
978*c217d954SCole Faust
979*c217d954SCole Faust const TensorInfo output_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
980*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&output_outstage_info, &output_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
981*c217d954SCole Faust
982*c217d954SCole Faust // Hidden.
983*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(cell_state_out, &input_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f)));
984*c217d954SCole Faust const TensorInfo hidden_mul_res(TensorShape(num_units, batch_size), 1, DataType::S32);
985*c217d954SCole Faust const TensorInfo hidden_out_info(TensorShape(num_units, batch_size), 1, DataType::QASYMM8_SIGNED);
986*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(&output_gate_info, &input_gate_info, &hidden_mul_res, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
987*c217d954SCole Faust
988*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.hidden_state_scale() == 0);
989*c217d954SCole Faust const float hidden_state_scale = std::pow(2, -15) / lstm_params.hidden_state_scale() * std::pow(2, -15);
990*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(hidden_state_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift, /* ignore_epsilon */ true));
991*c217d954SCole Faust gemmlowp_info.gemmlowp_offset = lstm_params.hidden_state_zero();
992*c217d954SCole Faust gemmlowp_info.output_data_type = hidden_out_info.data_type();
993*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpOutputStage::validate(&hidden_mul_res, nullptr, &hidden_out_info, gemmlowp_info));
994*c217d954SCole Faust
995*c217d954SCole Faust const bool projection_tensor_copy_required = num_units != output_size;
996*c217d954SCole Faust
997*c217d954SCole Faust // Projection.
998*c217d954SCole Faust if(lstm_params.has_projection())
999*c217d954SCole Faust {
1000*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(recurrent_to_forget_weights, lstm_params.projection_weights());
1001*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON(qoutput_state_in.scale == 0);
1002*c217d954SCole Faust
1003*c217d954SCole Faust const UniformQuantizationInfo qprojection = lstm_params.projection_weights()->quantization_info().uniform();
1004*c217d954SCole Faust const float projection_scale = qprojection.scale * lstm_params.hidden_state_scale() / qoutput_state_in.scale;
1005*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(projection_scale, &gemmlowp_info.gemmlowp_multiplier, &gemmlowp_info.gemmlowp_shift));
1006*c217d954SCole Faust gemmlowp_info.gemmlowp_offset = qoutput_state_in.offset;
1007*c217d954SCole Faust gemmlowp_info.gemmlowp_min_bound = std::numeric_limits<int8_t>::lowest();
1008*c217d954SCole Faust gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int8_t>::max();
1009*c217d954SCole Faust gemmlowp_info.output_data_type = DataType::QASYMM8_SIGNED;
1010*c217d954SCole Faust
1011*c217d954SCole Faust const TensorInfo projection_outstage_info(*output_state_out);
1012*c217d954SCole Faust const TensorInfo projection_weights_transposed(TensorShape(output_size, num_units), 1, lstm_params.projection_weights()->data_type(), lstm_params.projection_weights()->quantization_info());
1013*c217d954SCole Faust
1014*c217d954SCole Faust TensorInfo projection_mm_out_info{ mm_out_info };
1015*c217d954SCole Faust projection_mm_out_info.set_tensor_shape(TensorShape(output_size, batch_size));
1016*c217d954SCole Faust
1017*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemmlowp_info, &hidden_out_info, &projection_weights_transposed, &projection_eff_bias_info, projection_scale, &projection_mm_out_info,
1018*c217d954SCole Faust &projection_outstage_info));
1019*c217d954SCole Faust
1020*c217d954SCole Faust if(projection_tensor_copy_required)
1021*c217d954SCole Faust {
1022*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(NEQLSTMLayer::TensorCopyKernel::validate(*output_state_in, projection_outstage_info));
1023*c217d954SCole Faust }
1024*c217d954SCole Faust
1025*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(output_state_out, output_state_out, output_state_out, ConvertPolicy::SATURATE));
1026*c217d954SCole Faust
1027*c217d954SCole Faust if(projection_tensor_copy_required)
1028*c217d954SCole Faust {
1029*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(NEQLSTMLayer::TensorCopyKernel::validate(projection_outstage_info, *output_state_out));
1030*c217d954SCole Faust }
1031*c217d954SCole Faust
1032*c217d954SCole Faust int8_t quantized_projection_clip{ 0 };
1033*c217d954SCole Faust if(lstm_params.projection_clip() > 0.0f)
1034*c217d954SCole Faust {
1035*c217d954SCole Faust quantized_projection_clip = quantize_qasymm8_signed(lstm_params.projection_clip(), qprojection);
1036*c217d954SCole Faust }
1037*c217d954SCole Faust
1038*c217d954SCole Faust if(quantized_projection_clip > 0)
1039*c217d954SCole Faust {
1040*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(output_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -quantized_projection_clip,
1041*c217d954SCole Faust quantized_projection_clip)));
1042*c217d954SCole Faust }
1043*c217d954SCole Faust }
1044*c217d954SCole Faust else
1045*c217d954SCole Faust {
1046*c217d954SCole Faust if(projection_tensor_copy_required)
1047*c217d954SCole Faust {
1048*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(NEQLSTMLayer::TensorCopyKernel::validate(hidden_out_info, *output_state_out));
1049*c217d954SCole Faust }
1050*c217d954SCole Faust }
1051*c217d954SCole Faust
1052*c217d954SCole Faust if(cell_state_out->total_size() > 0)
1053*c217d954SCole Faust {
1054*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(cell_state_in, cell_state_out);
1055*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(cell_state_in, cell_state_out);
1056*c217d954SCole Faust }
1057*c217d954SCole Faust
1058*c217d954SCole Faust if(output_state_out->total_size() > 0)
1059*c217d954SCole Faust {
1060*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output_state_out);
1061*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output_state_in, output_state_out);
1062*c217d954SCole Faust }
1063*c217d954SCole Faust
1064*c217d954SCole Faust ARM_COMPUTE_RETURN_ON_ERROR(NECopy::validate(output_state_out, output));
1065*c217d954SCole Faust return Status{};
1066*c217d954SCole Faust }
1067*c217d954SCole Faust
run()1068*c217d954SCole Faust void NEQLSTMLayer::run()
1069*c217d954SCole Faust {
1070*c217d954SCole Faust prepare();
1071*c217d954SCole Faust
1072*c217d954SCole Faust // Acquire all the temporaries
1073*c217d954SCole Faust MemoryGroupResourceScope scope_mg(_memory_group);
1074*c217d954SCole Faust
1075*c217d954SCole Faust // Forget gate.
1076*c217d954SCole Faust _mm_input_to_forget.run();
1077*c217d954SCole Faust _input_to_forget_outstage.run();
1078*c217d954SCole Faust
1079*c217d954SCole Faust _mm_recurrent_to_forget.run();
1080*c217d954SCole Faust _recurrent_to_forget_outstage.run();
1081*c217d954SCole Faust _accumulate_input_recurrent_forget.run();
1082*c217d954SCole Faust
1083*c217d954SCole Faust if(_has_peephole)
1084*c217d954SCole Faust {
1085*c217d954SCole Faust _pixelwise_mul_cell_to_forget.run();
1086*c217d954SCole Faust _cell_to_forget_outstage.run();
1087*c217d954SCole Faust _accumulate_cell_forget.run();
1088*c217d954SCole Faust }
1089*c217d954SCole Faust
1090*c217d954SCole Faust if(_has_layer_norm)
1091*c217d954SCole Faust {
1092*c217d954SCole Faust NEScheduler::get().schedule(get_layer_norm(LayerNormGate::Forget).get(), Window::DimY);
1093*c217d954SCole Faust }
1094*c217d954SCole Faust
1095*c217d954SCole Faust _forget_gate_sigmoid.run();
1096*c217d954SCole Faust
1097*c217d954SCole Faust // Modulation gate.
1098*c217d954SCole Faust _mm_input_to_cell.run();
1099*c217d954SCole Faust _input_to_cell_outstage.run();
1100*c217d954SCole Faust
1101*c217d954SCole Faust _mm_recurrent_to_cell.run();
1102*c217d954SCole Faust _recurrent_to_cell_outstage.run();
1103*c217d954SCole Faust _accumulate_input_recurrent_modulation.run();
1104*c217d954SCole Faust
1105*c217d954SCole Faust if(_has_layer_norm)
1106*c217d954SCole Faust {
1107*c217d954SCole Faust NEScheduler::get().schedule(get_layer_norm(LayerNormGate::Cell).get(), Window::DimY);
1108*c217d954SCole Faust }
1109*c217d954SCole Faust
1110*c217d954SCole Faust _cell_gate_tanh.run();
1111*c217d954SCole Faust
1112*c217d954SCole Faust // Input gate
1113*c217d954SCole Faust if(_has_cifg)
1114*c217d954SCole Faust {
1115*c217d954SCole Faust _input_gate_sub.run();
1116*c217d954SCole Faust }
1117*c217d954SCole Faust else
1118*c217d954SCole Faust {
1119*c217d954SCole Faust _mm_input_to_input.run();
1120*c217d954SCole Faust _input_to_input_outstage.run();
1121*c217d954SCole Faust _mm_recurrent_to_input.run();
1122*c217d954SCole Faust _recurrent_to_input_outstage.run();
1123*c217d954SCole Faust _accumulate_input_recurrent_input.run();
1124*c217d954SCole Faust
1125*c217d954SCole Faust if(_has_peephole)
1126*c217d954SCole Faust {
1127*c217d954SCole Faust _pixelwise_mul_cell_to_input.run();
1128*c217d954SCole Faust _cell_to_input_outstage.run();
1129*c217d954SCole Faust _accumulate_cell_input.run();
1130*c217d954SCole Faust }
1131*c217d954SCole Faust
1132*c217d954SCole Faust if(_has_layer_norm)
1133*c217d954SCole Faust {
1134*c217d954SCole Faust NEScheduler::get().schedule(get_layer_norm(LayerNormGate::Input).get(), Window::DimY);
1135*c217d954SCole Faust }
1136*c217d954SCole Faust
1137*c217d954SCole Faust _input_gate_sigmoid.run();
1138*c217d954SCole Faust }
1139*c217d954SCole Faust
1140*c217d954SCole Faust // Cell.
1141*c217d954SCole Faust _pixelwise_mul_forget_cell.run();
1142*c217d954SCole Faust _pixelwise_mul_input_cell.run();
1143*c217d954SCole Faust _add_forget_cell.run();
1144*c217d954SCole Faust
1145*c217d954SCole Faust if(_has_cell_clipping)
1146*c217d954SCole Faust {
1147*c217d954SCole Faust _cell_clip.run();
1148*c217d954SCole Faust }
1149*c217d954SCole Faust
1150*c217d954SCole Faust // Output gate.
1151*c217d954SCole Faust _mm_input_to_output.run();
1152*c217d954SCole Faust _input_to_output_outstage.run();
1153*c217d954SCole Faust _mm_recurrent_to_output.run();
1154*c217d954SCole Faust _recurrent_to_output_outstage.run();
1155*c217d954SCole Faust _accumulate_input_recurrent_output.run();
1156*c217d954SCole Faust if(_has_peephole)
1157*c217d954SCole Faust {
1158*c217d954SCole Faust _pixelwise_mul_cell_to_output.run();
1159*c217d954SCole Faust _cell_to_output_outstage.run();
1160*c217d954SCole Faust _accumulate_cell_to_output.run();
1161*c217d954SCole Faust }
1162*c217d954SCole Faust
1163*c217d954SCole Faust if(_has_layer_norm)
1164*c217d954SCole Faust {
1165*c217d954SCole Faust NEScheduler::get().schedule(get_layer_norm(LayerNormGate::Output).get(), Window::DimY);
1166*c217d954SCole Faust }
1167*c217d954SCole Faust
1168*c217d954SCole Faust _output_gate_sigmoid.run();
1169*c217d954SCole Faust
1170*c217d954SCole Faust // Hidden.
1171*c217d954SCole Faust _hidden_tanh.run();
1172*c217d954SCole Faust _pixelwise_mul_hidden.run();
1173*c217d954SCole Faust _hidden_outstage.run();
1174*c217d954SCole Faust
1175*c217d954SCole Faust // Projection.
1176*c217d954SCole Faust if(_has_projection)
1177*c217d954SCole Faust {
1178*c217d954SCole Faust _mm_projection.run();
1179*c217d954SCole Faust _projection_outstage.run();
1180*c217d954SCole Faust
1181*c217d954SCole Faust if(_projection_tensor_copy_required)
1182*c217d954SCole Faust {
1183*c217d954SCole Faust _projection_output_to_accumulate_copy.run();
1184*c217d954SCole Faust }
1185*c217d954SCole Faust
1186*c217d954SCole Faust _accumulate_projection.run();
1187*c217d954SCole Faust
1188*c217d954SCole Faust if(_projection_tensor_copy_required)
1189*c217d954SCole Faust {
1190*c217d954SCole Faust _projection_accumulate_to_output_copy.run();
1191*c217d954SCole Faust }
1192*c217d954SCole Faust
1193*c217d954SCole Faust if(_has_projection_clipping)
1194*c217d954SCole Faust {
1195*c217d954SCole Faust _projection_clip.run();
1196*c217d954SCole Faust }
1197*c217d954SCole Faust }
1198*c217d954SCole Faust else
1199*c217d954SCole Faust {
1200*c217d954SCole Faust if(_projection_tensor_copy_required)
1201*c217d954SCole Faust {
1202*c217d954SCole Faust _hidden_to_output_copy.run();
1203*c217d954SCole Faust }
1204*c217d954SCole Faust }
1205*c217d954SCole Faust
1206*c217d954SCole Faust // Copy output_state_out to output
1207*c217d954SCole Faust _copy_output.run();
1208*c217d954SCole Faust }
1209*c217d954SCole Faust
prepare()1210*c217d954SCole Faust void NEQLSTMLayer::prepare()
1211*c217d954SCole Faust {
1212*c217d954SCole Faust if(!_is_prepared)
1213*c217d954SCole Faust {
1214*c217d954SCole Faust if(_convert_input_to_forget_weights_to_qsymm8)
1215*c217d954SCole Faust {
1216*c217d954SCole Faust _input_to_forget_weights_f32.allocator()->allocate();
1217*c217d954SCole Faust _input_to_forget_weights_symm8.allocator()->allocate();
1218*c217d954SCole Faust _dequantize_input_to_forget_weights.run();
1219*c217d954SCole Faust _quantize_input_to_forget_weights.run();
1220*c217d954SCole Faust }
1221*c217d954SCole Faust
1222*c217d954SCole Faust // Pre-transpose weights to be used in GEMM.
1223*c217d954SCole Faust _input_to_forget_weights_transposed.allocator()->allocate();
1224*c217d954SCole Faust _input_to_cell_weights_transposed.allocator()->allocate();
1225*c217d954SCole Faust _input_to_output_weights_transposed.allocator()->allocate();
1226*c217d954SCole Faust _recurrent_to_forget_weights_transposed.allocator()->allocate();
1227*c217d954SCole Faust _recurrent_to_cell_weights_transposed.allocator()->allocate();
1228*c217d954SCole Faust _recurrent_to_output_weights_transposed.allocator()->allocate();
1229*c217d954SCole Faust _transpose_input_to_forget_weights.run();
1230*c217d954SCole Faust _transpose_input_to_cell_weights.run();
1231*c217d954SCole Faust _transpose_input_to_output_weights.run();
1232*c217d954SCole Faust _transpose_recurrent_to_forget_weights.run();
1233*c217d954SCole Faust _transpose_recurrent_to_cell_weights.run();
1234*c217d954SCole Faust _transpose_recurrent_to_output_weights.run();
1235*c217d954SCole Faust
1236*c217d954SCole Faust // Precompute effective biases
1237*c217d954SCole Faust if(_has_cifg)
1238*c217d954SCole Faust {
1239*c217d954SCole Faust std::fill_n(reinterpret_cast<int16_t *>(_ones.buffer()), _ones.info()->total_size() / _ones.info()->element_size(), 32767);
1240*c217d954SCole Faust }
1241*c217d954SCole Faust else
1242*c217d954SCole Faust {
1243*c217d954SCole Faust _input_to_input_eff_bias.allocator()->allocate();
1244*c217d954SCole Faust _recurrent_to_input_eff_bias.allocator()->allocate();
1245*c217d954SCole Faust
1246*c217d954SCole Faust ITensorPack packII =
1247*c217d954SCole Faust {
1248*c217d954SCole Faust { TensorType::ACL_SRC, _input_to_input_weights },
1249*c217d954SCole Faust { TensorType::ACL_DST, &_input_to_input_eff_bias }
1250*c217d954SCole Faust };
1251*c217d954SCole Faust NEScheduler::get().schedule_op(_input_to_input_reduction.get(), Window::DimY, _input_to_input_reduction->window(), packII);
1252*c217d954SCole Faust
1253*c217d954SCole Faust ITensorPack packRI =
1254*c217d954SCole Faust {
1255*c217d954SCole Faust { TensorType::ACL_SRC, _recurrent_to_input_weights },
1256*c217d954SCole Faust { TensorType::ACL_DST, &_recurrent_to_input_eff_bias }
1257*c217d954SCole Faust };
1258*c217d954SCole Faust NEScheduler::get().schedule_op(_recurrent_to_input_reduction.get(), Window::DimY, _recurrent_to_input_reduction->window(), packRI);
1259*c217d954SCole Faust
1260*c217d954SCole Faust _input_to_input_weights_transposed.allocator()->allocate();
1261*c217d954SCole Faust _recurrent_to_input_weights_transposed.allocator()->allocate();
1262*c217d954SCole Faust _transpose_input_to_input_weights.run();
1263*c217d954SCole Faust _transpose_recurrent_to_input_weights.run();
1264*c217d954SCole Faust _input_to_input_weights->mark_as_unused();
1265*c217d954SCole Faust _recurrent_to_input_weights->mark_as_unused();
1266*c217d954SCole Faust }
1267*c217d954SCole Faust _input_to_forget_eff_bias.allocator()->allocate();
1268*c217d954SCole Faust _recurrent_to_forget_eff_bias.allocator()->allocate();
1269*c217d954SCole Faust _input_to_cell_eff_bias.allocator()->allocate();
1270*c217d954SCole Faust _recurrent_to_cell_eff_bias.allocator()->allocate();
1271*c217d954SCole Faust _input_to_output_eff_bias.allocator()->allocate();
1272*c217d954SCole Faust _recurrent_to_output_eff_bias.allocator()->allocate();
1273*c217d954SCole Faust
1274*c217d954SCole Faust ITensorPack packIF =
1275*c217d954SCole Faust {
1276*c217d954SCole Faust { TensorType::ACL_SRC, _input_to_forget_weights },
1277*c217d954SCole Faust { TensorType::ACL_DST, &_input_to_forget_eff_bias }
1278*c217d954SCole Faust };
1279*c217d954SCole Faust NEScheduler::get().schedule_op(_input_to_forget_reduction.get(), Window::DimY, _input_to_forget_reduction->window(), packIF);
1280*c217d954SCole Faust
1281*c217d954SCole Faust ITensorPack packRF =
1282*c217d954SCole Faust {
1283*c217d954SCole Faust { TensorType::ACL_SRC, _recurrent_to_forget_weights },
1284*c217d954SCole Faust { TensorType::ACL_DST, &_recurrent_to_forget_eff_bias }
1285*c217d954SCole Faust };
1286*c217d954SCole Faust NEScheduler::get().schedule_op(_recurrent_to_forget_reduction.get(), Window::DimY, _recurrent_to_forget_reduction->window(), packRF);
1287*c217d954SCole Faust
1288*c217d954SCole Faust ITensorPack packIC =
1289*c217d954SCole Faust {
1290*c217d954SCole Faust { TensorType::ACL_SRC, _input_to_cell_weights },
1291*c217d954SCole Faust { TensorType::ACL_DST, &_input_to_cell_eff_bias }
1292*c217d954SCole Faust };
1293*c217d954SCole Faust NEScheduler::get().schedule_op(_input_to_cell_reduction.get(), Window::DimY, _input_to_cell_reduction->window(), packIC);
1294*c217d954SCole Faust
1295*c217d954SCole Faust ITensorPack packRC =
1296*c217d954SCole Faust {
1297*c217d954SCole Faust { TensorType::ACL_SRC, _recurrent_to_cell_weights },
1298*c217d954SCole Faust { TensorType::ACL_DST, &_recurrent_to_cell_eff_bias }
1299*c217d954SCole Faust };
1300*c217d954SCole Faust NEScheduler::get().schedule_op(_recurrent_to_cell_reduction.get(), Window::DimY, _recurrent_to_cell_reduction->window(), packRC);
1301*c217d954SCole Faust
1302*c217d954SCole Faust ITensorPack packIO =
1303*c217d954SCole Faust {
1304*c217d954SCole Faust { TensorType::ACL_SRC, _input_to_output_weights },
1305*c217d954SCole Faust { TensorType::ACL_DST, &_input_to_output_eff_bias }
1306*c217d954SCole Faust };
1307*c217d954SCole Faust NEScheduler::get().schedule_op(_input_to_output_reduction.get(), Window::DimY, _input_to_output_reduction->window(), packIO);
1308*c217d954SCole Faust
1309*c217d954SCole Faust ITensorPack packRO =
1310*c217d954SCole Faust {
1311*c217d954SCole Faust { TensorType::ACL_SRC, _recurrent_to_output_weights },
1312*c217d954SCole Faust { TensorType::ACL_DST, &_recurrent_to_output_eff_bias }
1313*c217d954SCole Faust };
1314*c217d954SCole Faust NEScheduler::get().schedule_op(_recurrent_to_output_reduction.get(), Window::DimY, _recurrent_to_output_reduction->window(), packRO);
1315*c217d954SCole Faust
1316*c217d954SCole Faust if(_has_projection)
1317*c217d954SCole Faust {
1318*c217d954SCole Faust _projection_eff_bias.allocator()->allocate();
1319*c217d954SCole Faust ITensorPack pack =
1320*c217d954SCole Faust {
1321*c217d954SCole Faust { TensorType::ACL_SRC, _projection_weights },
1322*c217d954SCole Faust { TensorType::ACL_DST, &_projection_eff_bias }
1323*c217d954SCole Faust };
1324*c217d954SCole Faust NEScheduler::get().schedule_op(_projection_reduction.get(), Window::DimY, _projection_reduction->window(), pack);
1325*c217d954SCole Faust if(_projection_bias != nullptr)
1326*c217d954SCole Faust {
1327*c217d954SCole Faust _projection_bias_add.run();
1328*c217d954SCole Faust _projection_bias->mark_as_unused();
1329*c217d954SCole Faust }
1330*c217d954SCole Faust
1331*c217d954SCole Faust _projection_weights_transposed.allocator()->allocate();
1332*c217d954SCole Faust _transpose_projection_weights.run();
1333*c217d954SCole Faust _projection_weights->mark_as_unused();
1334*c217d954SCole Faust
1335*c217d954SCole Faust if(!_projection_tensor_copy_required)
1336*c217d954SCole Faust {
1337*c217d954SCole Faust _hidden_gate.mark_as_unused();
1338*c217d954SCole Faust _projection_accumulate_res.mark_as_unused();
1339*c217d954SCole Faust }
1340*c217d954SCole Faust }
1341*c217d954SCole Faust
1342*c217d954SCole Faust // Mark weights as unused
1343*c217d954SCole Faust _input_to_forget_weights->mark_as_unused();
1344*c217d954SCole Faust _input_to_cell_weights->mark_as_unused();
1345*c217d954SCole Faust _input_to_output_weights->mark_as_unused();
1346*c217d954SCole Faust _recurrent_to_forget_weights->mark_as_unused();
1347*c217d954SCole Faust _recurrent_to_cell_weights->mark_as_unused();
1348*c217d954SCole Faust _recurrent_to_output_weights->mark_as_unused();
1349*c217d954SCole Faust
1350*c217d954SCole Faust _is_prepared = true;
1351*c217d954SCole Faust }
1352*c217d954SCole Faust }
1353*c217d954SCole Faust } // namespace arm_compute
1354