xref: /aosp_15_r20/external/ComputeLibrary/src/runtime/NEON/functions/NELSTMLayerQuantized.cpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1 /*
2  * Copyright (c) 2019-2021 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "arm_compute/runtime/NEON/functions/NELSTMLayerQuantized.h"
25 
26 #include "arm_compute/core/Utils.h"
27 #include "arm_compute/core/Validate.h"
28 #include "arm_compute/core/utils/quantization/AsymmHelpers.h"
29 #include "src/common/utils/Log.h"
30 #include "src/core/helpers/AutoConfiguration.h"
31 
32 #include <cmath>
33 #include <memory>
34 #include <tuple>
35 
36 namespace arm_compute
37 {
38 namespace
39 {
40 // Quantization info structures used in the LSTMQuantize layer
41 const QuantizationInfo qasymm(1.f / 128.f, 128);
42 const QuantizationInfo qsymm_3(8.f / 32768.f, 0);  // qsymm16 with 3 integer bit
43 const QuantizationInfo qsymm_4(16.f / 32768.f, 0); // qsymm16 with 4 integer bit
44 const QuantizationInfo qsymm_0(1.f / 32768.f, 0);  // qsymm16 with 0 integer bit
45 } // namespace
46 NELSTMLayerQuantized::~NELSTMLayerQuantized() = default;
47 
NELSTMLayerQuantized(std::shared_ptr<IMemoryManager> memory_manager)48 NELSTMLayerQuantized::NELSTMLayerQuantized(std::shared_ptr<IMemoryManager> memory_manager)
49     : _memory_group(std::move(memory_manager)), _gemmlowp(), _output_stage(), _transpose_weights(), _concat_input_weights(), _concat_recurrent_weights(), _concat_weights(), _concat_inputs(),
50       _concat_bias(), _sigmoid_forget_gate(), _sigmoid_input_gate(), _sigmoid_output_gate(), _tanh_modulation_gate(), _tanh_output_state(), _add1(), _add2(), _mul1(), _mul2(), _mul3(),
51       _slice_input_tensor(), _slice_forget_tensor(), _slice_cell_tensor(), _slice_output_tensor(), _dequantize(), _quantize(), _input_to_input_weights(nullptr), _input_to_forget_weights(nullptr),
52       _input_to_cell_weights(nullptr), _input_to_output_weights(nullptr), _recurrent_to_input_weights(nullptr), _recurrent_to_forget_weights(nullptr), _recurrent_to_cell_weights(nullptr),
53       _recurrent_to_output_weights(nullptr), _input_gate_bias(nullptr), _forget_gate_bias(nullptr), _cell_bias(nullptr), _output_gate_bias(nullptr), _recurrent_weights(), _input_weights(), _weights(),
54       _input(), _weights_transposed(), _output_highp(), _output_lowp(), _bias(), _forget_gate_input(), _input_gate_input(), _output_gate_input(), _input_modulation_gate_input(), _forget_gate_output(),
55       _input_gate_output(), _output_gate_output(), _input_modulation_gate_output(), _cell_state1(), _cell_state2(), _output_state_tmp(), _output_state_out_symm(), _output_state_out_f32(),
56       _is_prepared(false)
57 {
58 }
59 
configure(const ITensor * input,const ITensor * input_to_input_weights,const ITensor * input_to_forget_weights,const ITensor * input_to_cell_weights,const ITensor * input_to_output_weights,const ITensor * recurrent_to_input_weights,const ITensor * recurrent_to_forget_weights,const ITensor * recurrent_to_cell_weights,const ITensor * recurrent_to_output_weights,const ITensor * input_gate_bias,const ITensor * forget_gate_bias,const ITensor * cell_bias,const ITensor * output_gate_bias,ITensor * cell_state_in,const ITensor * output_state_in,ITensor * cell_state_out,ITensor * output_state_out)60 void NELSTMLayerQuantized::configure(const ITensor *input,
61                                      const ITensor *input_to_input_weights, const ITensor *input_to_forget_weights, const ITensor *input_to_cell_weights, const ITensor *input_to_output_weights,
62                                      const ITensor *recurrent_to_input_weights, const ITensor *recurrent_to_forget_weights, const ITensor *recurrent_to_cell_weights, const ITensor *recurrent_to_output_weights,
63                                      const ITensor *input_gate_bias, const ITensor *forget_gate_bias, const ITensor *cell_bias, const ITensor *output_gate_bias,
64                                      ITensor *cell_state_in, const ITensor *output_state_in,
65                                      ITensor *cell_state_out, ITensor *output_state_out)
66 {
67     ARM_COMPUTE_ERROR_ON_NULLPTR(input, input_to_input_weights, input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
68                                  recurrent_to_input_weights, recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
69                                  input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in, cell_state_out, output_state_out);
70 
71     ARM_COMPUTE_ERROR_THROW_ON(NELSTMLayerQuantized::validate(input->info(), input_to_input_weights->info(), input_to_forget_weights->info(), input_to_cell_weights->info(),
72                                                               input_to_output_weights->info(),
73                                                               recurrent_to_input_weights->info(), recurrent_to_forget_weights->info(), recurrent_to_cell_weights->info(), recurrent_to_output_weights->info(),
74                                                               input_gate_bias->info(), forget_gate_bias->info(), cell_bias->info(), output_gate_bias->info(), cell_state_in->info(), output_state_in->info(), cell_state_out->info(), output_state_out->info()));
75 
76     ARM_COMPUTE_LOG_PARAMS(input, input_to_input_weights, input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
77                            recurrent_to_input_weights, recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
78                            input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in, cell_state_out, output_state_out);
79 
80     const int input_size  = input->info()->dimension(0);
81     const int batch_size  = input->info()->dimension(1);
82     const int output_size = input_to_input_weights->info()->dimension(1);
83 
84     const QuantizationInfo qweights = input_to_input_weights->info()->quantization_info(); // Weights quantization
85 
86     auto_init_if_empty(*cell_state_out->info(), TensorInfo(TensorShape(batch_size, output_size), 1, DataType::QSYMM16, qsymm_4));
87     auto_init_if_empty(*output_state_out->info(), TensorInfo(TensorShape(batch_size, output_size), 1, DataType::QASYMM8, qasymm));
88 
89     _input_to_input_weights      = input_to_input_weights;
90     _input_to_forget_weights     = input_to_forget_weights;
91     _input_to_cell_weights       = input_to_cell_weights;
92     _input_to_output_weights     = input_to_output_weights;
93     _recurrent_to_input_weights  = recurrent_to_input_weights;
94     _recurrent_to_forget_weights = recurrent_to_forget_weights;
95     _recurrent_to_cell_weights   = recurrent_to_cell_weights;
96     _recurrent_to_output_weights = recurrent_to_output_weights;
97     _input_gate_bias             = input_gate_bias;
98     _forget_gate_bias            = forget_gate_bias;
99     _cell_bias                   = cell_bias;
100     _output_gate_bias            = output_gate_bias;
101 
102     // Weights concatenation
103     std::vector<const ITensor *> inputs_weights_vector{ input_to_input_weights, input_to_forget_weights, input_to_cell_weights, input_to_output_weights };
104     std::vector<const ITensor *> recurrent_weights_vector{ recurrent_to_input_weights, recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights };
105 
106     _input_weights.allocator()->init(TensorInfo(TensorShape(input_size, 4 * output_size), 1, DataType::QASYMM8, qweights));
107     _concat_input_weights.configure(inputs_weights_vector, &_input_weights, Window::DimY);
108 
109     _recurrent_weights.allocator()->init(TensorInfo(TensorShape(output_size, 4 * output_size), 1, DataType::QASYMM8, qweights));
110     _concat_recurrent_weights.configure(recurrent_weights_vector, &_recurrent_weights, Window::DimY);
111 
112     std::vector<const ITensor *> weights_vector{ &_recurrent_weights, &_input_weights };
113     _weights.allocator()->init(TensorInfo(TensorShape(output_size + input_size, 4 * output_size), 1, DataType::QASYMM8, qweights));
114     _concat_weights.configure(weights_vector, &_weights, Window::DimX);
115     _transpose_weights.configure(&_weights, &_weights_transposed);
116 
117     // Input concatenation
118     std::vector<const ITensor *> input_vector{ input, output_state_in };
119     _memory_group.manage(&_input);
120     _input.allocator()->init(TensorInfo(TensorShape(output_size + input_size, batch_size), 1, DataType::QASYMM8, qasymm));
121     _concat_inputs.configure(input_vector, &_input, Window::DimX);
122 
123     // Bias concatenation
124     std::vector<const ITensor *> bias_vector{ input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias };
125     _bias.allocator()->init(TensorInfo(TensorShape(4 * output_size), 1, DataType::S32));
126     _concat_bias.configure(bias_vector, &_bias, Window::DimX);
127 
128     // Invert the offset for gemmlowp
129     _input.info()->set_quantization_info(QuantizationInfo(qasymm.uniform().scale, -qasymm.uniform().offset));
130     _weights_transposed.info()->set_quantization_info(QuantizationInfo(qweights.uniform().scale, -qweights.uniform().offset));
131 
132     // Run gemmlowp
133     _memory_group.manage(&_output_highp);
134     _output_highp.allocator()->init(TensorInfo(TensorShape(4 * output_size, batch_size), 1, DataType::S32));
135     _gemmlowp.configure(&_input, &_weights_transposed, nullptr, &_output_highp);
136     _input.allocator()->allocate();
137 
138     // Set the offset back
139     _input.info()->set_quantization_info(QuantizationInfo(qasymm.uniform().scale, qasymm.uniform().offset));
140     _weights_transposed.info()->set_quantization_info(QuantizationInfo(qweights.uniform().scale, qweights.uniform().offset));
141 
142     // multiplier = (input_scale * weights_scale) / output_scale (2 ^ (-12))
143     _output_lowp.allocator()->init(TensorInfo(_output_highp.info()->tensor_shape(), 1, DataType::QSYMM16, qsymm_3));
144 
145     const float multiplier        = 4096.f * qasymm.uniform().scale * qweights.uniform().scale;
146     int32_t     output_multiplier = 0;
147     int32_t     output_shift      = 0;
148     quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift);
149 
150     _memory_group.manage(&_output_lowp);
151 
152     GEMMLowpOutputStageInfo info;
153     info.type                = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
154     info.gemmlowp_multiplier = output_multiplier;
155     info.gemmlowp_shift      = output_shift;
156     info.output_data_type    = DataType::QSYMM16;
157     _output_stage.configure(&_output_highp, &_bias, &_output_lowp, info);
158     _output_highp.allocator()->allocate();
159     _bias.allocator()->allocate();
160 
161     // Get the gate tensors
162     if(batch_size > 1)
163     {
164         _memory_group.manage(&_input_gate_input);
165         _slice_input_tensor.configure(&_output_lowp, &_input_gate_input, { 0, 0 }, { output_size, batch_size });
166         _memory_group.manage(&_forget_gate_input);
167         _slice_forget_tensor.configure(&_output_lowp, &_forget_gate_input, { output_size, 0 }, { 2 * output_size, batch_size });
168         _memory_group.manage(&_input_modulation_gate_input);
169         _slice_cell_tensor.configure(&_output_lowp, &_input_modulation_gate_input, { 2 * output_size, 0 }, { 3 * output_size, batch_size });
170         _memory_group.manage(&_output_gate_input);
171         _slice_output_tensor.configure(&_output_lowp, &_output_gate_input, { 3 * output_size, 0 }, { 4 * output_size, batch_size });
172         _output_lowp.allocator()->allocate();
173     }
174     else
175     {
176         _memory_group.manage(&_input_gate_input);
177         _slice_input_tensor.configure(&_output_lowp, &_input_gate_input, { 0 }, { output_size });
178         _memory_group.manage(&_forget_gate_input);
179         _slice_forget_tensor.configure(&_output_lowp, &_forget_gate_input, { output_size }, { 2 * output_size });
180         _memory_group.manage(&_input_modulation_gate_input);
181         _slice_cell_tensor.configure(&_output_lowp, &_input_modulation_gate_input, { 2 * output_size }, { 3 * output_size });
182         _memory_group.manage(&_output_gate_input);
183         _slice_output_tensor.configure(&_output_lowp, &_output_gate_input, { 3 * output_size }, { 4 * output_size });
184         _output_lowp.allocator()->allocate();
185     }
186 
187     // Forget gate
188     _memory_group.manage(&_forget_gate_output);
189     _forget_gate_output.allocator()->init(TensorInfo(_forget_gate_input.info()->tensor_shape(), 1, DataType::QSYMM16, qsymm_0));
190     _sigmoid_forget_gate.configure(&_forget_gate_input, &_forget_gate_output, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
191     _forget_gate_input.allocator()->allocate();
192 
193     // Input gate
194     _memory_group.manage(&_input_gate_output);
195     _input_gate_output.allocator()->init(TensorInfo(_input_gate_input.info()->tensor_shape(), 1, DataType::QSYMM16, qsymm_0));
196     _sigmoid_input_gate.configure(&_input_gate_input, &_input_gate_output, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
197     _input_gate_input.allocator()->allocate();
198 
199     // Input modulation gate equation
200     _memory_group.manage(&_input_modulation_gate_output);
201     _input_modulation_gate_output.allocator()->init(TensorInfo(_input_modulation_gate_input.info()->tensor_shape(), 1, DataType::QSYMM16, qsymm_0));
202     _tanh_modulation_gate.configure(&_input_modulation_gate_input, &_input_modulation_gate_output, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.0f, 1.0f));
203     _input_modulation_gate_input.allocator()->allocate();
204 
205     // Output gate
206     _memory_group.manage(&_output_gate_output);
207     _output_gate_output.allocator()->init(TensorInfo(_output_gate_input.info()->tensor_shape(), 1, DataType::QSYMM16, qsymm_0));
208     _sigmoid_output_gate.configure(&_output_gate_input, &_output_gate_output, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
209     _output_gate_input.allocator()->allocate();
210 
211     // Long term memory
212     _memory_group.manage(&_cell_state1);
213     _cell_state1.allocator()->init(TensorInfo(_forget_gate_output.info()->tensor_shape(), 1, DataType::QSYMM16, qsymm_4));
214     _mul1.configure(&_forget_gate_output, cell_state_in, &_cell_state1, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
215     _forget_gate_output.allocator()->allocate();
216 
217     _memory_group.manage(&_cell_state2);
218     _cell_state2.allocator()->init(TensorInfo(_input_gate_output.info()->tensor_shape(), 1, DataType::QSYMM16, qsymm_4));
219     _mul2.configure(&_input_gate_output, &_input_modulation_gate_output, &_cell_state2, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
220     _input_modulation_gate_output.allocator()->allocate();
221     _input_gate_output.allocator()->allocate();
222 
223     _add1.configure(&_cell_state1, &_cell_state2, cell_state_out, ConvertPolicy::SATURATE);
224     _cell_state1.allocator()->allocate();
225     _cell_state2.allocator()->allocate();
226 
227     // Short term memory
228     _memory_group.manage(&_output_state_tmp);
229     _output_state_tmp.allocator()->init(TensorInfo(cell_state_out->info()->tensor_shape(), 1, DataType::QSYMM16, qsymm_0));
230     _tanh_output_state.configure(cell_state_out, &_output_state_tmp, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.0f, 1.0f));
231 
232     _memory_group.manage(&_output_state_out_symm);
233     _output_state_out_symm.allocator()->init(TensorInfo(_output_gate_output.info()->tensor_shape(), 1, DataType::QSYMM16, qsymm_0));
234     _mul3.configure(&_output_state_tmp, &_output_gate_output, &_output_state_out_symm, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
235     _output_gate_output.allocator()->allocate();
236     _output_state_tmp.allocator()->allocate();
237 
238     // Requantize the output state from QSYMM16 to QASYMM8
239     _memory_group.manage(&_output_state_out_f32);
240     _output_state_out_f32.allocator()->init(TensorInfo(_output_state_out_symm.info()->tensor_shape(), 1, DataType::F32));
241     _dequantize.configure(&_output_state_out_symm, &_output_state_out_f32);
242     _output_state_out_symm.allocator()->allocate();
243 
244     _quantize.configure(&_output_state_out_f32, output_state_out);
245     _output_state_out_f32.allocator()->allocate();
246 }
247 
validate(const ITensorInfo * input,const ITensorInfo * input_to_input_weights,const ITensorInfo * input_to_forget_weights,const ITensorInfo * input_to_cell_weights,const ITensorInfo * input_to_output_weights,const ITensorInfo * recurrent_to_input_weights,const ITensorInfo * recurrent_to_forget_weights,const ITensorInfo * recurrent_to_cell_weights,const ITensorInfo * recurrent_to_output_weights,const ITensorInfo * input_gate_bias,const ITensorInfo * forget_gate_bias,const ITensorInfo * cell_bias,const ITensorInfo * output_gate_bias,const ITensorInfo * cell_state_in,const ITensorInfo * output_state_in,const ITensorInfo * cell_state_out,const ITensorInfo * output_state_out)248 Status NELSTMLayerQuantized::validate(const ITensorInfo *input,
249                                       const ITensorInfo *input_to_input_weights, const ITensorInfo *input_to_forget_weights, const ITensorInfo *input_to_cell_weights, const ITensorInfo *input_to_output_weights,
250                                       const ITensorInfo *recurrent_to_input_weights, const ITensorInfo *recurrent_to_forget_weights, const ITensorInfo *recurrent_to_cell_weights, const ITensorInfo *recurrent_to_output_weights,
251                                       const ITensorInfo *input_gate_bias, const ITensorInfo *forget_gate_bias, const ITensorInfo *cell_bias, const ITensorInfo *output_gate_bias,
252                                       const ITensorInfo *cell_state_in, const ITensorInfo *output_state_in,
253                                       const ITensorInfo *cell_state_out, const ITensorInfo *output_state_out)
254 {
255     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, input_to_input_weights, input_to_forget_weights, input_to_cell_weights, input_to_output_weights, recurrent_to_input_weights,
256                                         recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights, input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias, cell_state_in,
257                                         output_state_in, cell_state_out, output_state_out);
258 
259     const int input_size  = input->dimension(0);
260     const int batch_size  = input->dimension(1);
261     const int output_size = input_to_input_weights->dimension(1);
262 
263     // Dimensionality checks
264     ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() > 2);
265     ARM_COMPUTE_RETURN_ERROR_ON(input_to_input_weights->num_dimensions() > 2);
266     ARM_COMPUTE_RETURN_ERROR_ON(input_gate_bias->num_dimensions() > 1);
267     ARM_COMPUTE_RETURN_ERROR_ON(output_state_in->num_dimensions() > 2);
268 
269     TensorInfo input_weights_info(input_to_input_weights->clone()->set_tensor_shape(TensorShape(input_size, output_size)).set_data_type(DataType::QASYMM8));
270     TensorInfo recurrent_weights_info(input_to_input_weights->clone()->set_tensor_shape(TensorShape(output_size, output_size)).set_data_type(DataType::QASYMM8));
271     TensorInfo bias_info(input_gate_bias->clone()->set_tensor_shape(TensorShape(output_size)).set_data_type(DataType::S32));
272     TensorInfo output_state_info(cell_state_in->clone()->set_tensor_shape(TensorShape(output_size, batch_size)).set_data_type(DataType::QASYMM8).set_quantization_info(qasymm));
273     TensorInfo cell_state_info(cell_state_in->clone()->set_tensor_shape(TensorShape(output_size, batch_size)).set_data_type(DataType::QSYMM16).set_quantization_info(qsymm_4));
274 
275     // Shape checks
276     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&input_weights_info, input_to_input_weights, input_to_forget_weights, input_to_cell_weights, input_to_output_weights);
277     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&recurrent_weights_info, recurrent_to_input_weights, recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights);
278     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&bias_info, input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias);
279     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&cell_state_info, cell_state_in);
280     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&output_state_info, output_state_in);
281 
282     // Data type checks
283     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input_weights_info, input, input_to_input_weights, input_to_forget_weights, input_to_cell_weights, input_to_output_weights);
284     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(recurrent_to_input_weights, recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights);
285     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&bias_info, input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias);
286     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&cell_state_info, cell_state_in);
287     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&output_state_info, output_state_in);
288 
289     // Quantization checks
290     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(&input_weights_info, input_to_input_weights, input_to_forget_weights, input_to_cell_weights, input_to_output_weights);
291     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(recurrent_to_input_weights, recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights);
292     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(&cell_state_info, cell_state_in);
293     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(&output_state_info, output_state_in);
294 
295     // Validate internal functions
296     // _concat_input_weights
297     std::vector<const ITensorInfo *> inputs_weights_vector;
298     inputs_weights_vector.emplace_back(input_to_input_weights);
299     inputs_weights_vector.emplace_back(input_to_forget_weights);
300     inputs_weights_vector.emplace_back(input_to_cell_weights);
301     inputs_weights_vector.emplace_back(input_to_output_weights);
302     const QuantizationInfo qweights = input_to_input_weights->quantization_info(); // Weights quantization
303     const TensorInfo       input_weights(TensorShape(input_size, 4 * output_size), 1, DataType::QASYMM8, qweights);
304     ARM_COMPUTE_RETURN_ON_ERROR(NEConcatenateLayer::validate(inputs_weights_vector, &input_weights, Window::DimY));
305 
306     // _concat_recurrent_weights
307     std::vector<const ITensorInfo *> recurrent_weights_vector;
308     recurrent_weights_vector.emplace_back(recurrent_to_input_weights);
309     recurrent_weights_vector.emplace_back(recurrent_to_forget_weights);
310     recurrent_weights_vector.emplace_back(recurrent_to_cell_weights);
311     recurrent_weights_vector.emplace_back(recurrent_to_output_weights);
312     const TensorInfo recurrent_weights(TensorShape(output_size, 4 * output_size), 1, DataType::QASYMM8, qweights);
313     ARM_COMPUTE_RETURN_ON_ERROR(NEConcatenateLayer::validate(recurrent_weights_vector, &recurrent_weights, Window::DimY));
314 
315     // _concat_weights
316     std::vector<const ITensorInfo *> weights_vector;
317     weights_vector.emplace_back(&recurrent_weights);
318     weights_vector.emplace_back(&input_weights);
319     const TensorInfo weights(TensorShape(input_size + output_size, 4 * output_size), 1, DataType::QASYMM8, qweights);
320     ARM_COMPUTE_RETURN_ON_ERROR(NEConcatenateLayer::validate(weights_vector, &weights, Window::DimX));
321     // _transpose_weights
322     const TensorShape weights_transposed_shape(weights.tensor_shape()[1], weights.tensor_shape()[0]);
323     TensorInfo        weights_transposed = weights.clone()->set_is_resizable(true).set_tensor_shape(weights_transposed_shape);
324     ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(&weights, &weights_transposed));
325 
326     // _concat_inputs
327     std::vector<const ITensorInfo *> input_vector;
328     input_vector.emplace_back(input);
329     input_vector.emplace_back(output_state_in);
330     TensorInfo input_concatenated(TensorShape(output_size + input_size, batch_size), 1, DataType::QASYMM8, qasymm);
331     ARM_COMPUTE_RETURN_ON_ERROR(NEConcatenateLayer::validate(input_vector, &input_concatenated, Window::DimX));
332 
333     // _concat_bias
334     std::vector<const ITensorInfo *> bias_vector;
335     bias_vector.emplace_back(input_gate_bias);
336     bias_vector.emplace_back(forget_gate_bias);
337     bias_vector.emplace_back(cell_bias);
338     bias_vector.emplace_back(output_gate_bias);
339 
340     const TensorInfo bias_concatenated(TensorShape(4 * output_size), 1, DataType::S32);
341     ARM_COMPUTE_RETURN_ON_ERROR(NEConcatenateLayer::validate(bias_vector, &bias_concatenated, Window::DimX));
342 
343     // Invert the offset for gemmlowp
344     input_concatenated.set_quantization_info(QuantizationInfo(qasymm.uniform().scale, -qasymm.uniform().offset));
345     weights_transposed.set_quantization_info(QuantizationInfo(qweights.uniform().scale, -qweights.uniform().offset));
346 
347     // _gemmlowp
348     const TensorInfo output_highp(TensorShape(4 * output_size, batch_size), 1, DataType::S32);
349     ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpMatrixMultiplyCore::validate(&input_concatenated, &weights_transposed, nullptr, &output_highp));
350 
351     // Set the offset back
352     input_concatenated.set_quantization_info(QuantizationInfo(qasymm.uniform().scale, qasymm.uniform().offset));
353     weights_transposed.set_quantization_info(QuantizationInfo(qweights.uniform().scale, qweights.uniform().offset));
354 
355     const TensorInfo output_lowp(output_highp.tensor_shape(), 1, DataType::QSYMM16, qsymm_3);
356 
357     const float multiplier        = 4096.f * qasymm.uniform().scale * qweights.uniform().scale;
358     int32_t     output_multiplier = 0;
359     int32_t     output_shift      = 0;
360     ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift));
361 
362     // _output_stage
363     GEMMLowpOutputStageInfo info;
364     info.type                = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
365     info.gemmlowp_multiplier = output_multiplier;
366     info.gemmlowp_shift      = output_shift;
367     info.output_data_type    = DataType::QSYMM16;
368     ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMLowpOutputStage::validate(&output_highp, &bias_concatenated, &output_lowp, info));
369 
370     TensorInfo input_gate_input;
371     TensorInfo forget_gate_input;
372     TensorInfo input_modulation_gate_input;
373     TensorInfo output_gate_input;
374 
375     if(batch_size > 1)
376     {
377         // _slice_input_tensor
378         input_gate_input = TensorInfo(TensorShape(output_size, batch_size), 1, DataType::QSYMM16, qsymm_3);
379         ARM_COMPUTE_RETURN_ON_ERROR(NESlice::validate(&output_lowp, &input_gate_input, { 0, 0 }, { output_size, batch_size }));
380         // _slice_forget_tensor
381         forget_gate_input = TensorInfo(TensorShape(output_size, batch_size), 1, DataType::QSYMM16, qsymm_3);
382         ARM_COMPUTE_RETURN_ON_ERROR(NESlice::validate(&output_lowp, &forget_gate_input, { output_size, 0 }, { 2 * output_size, batch_size }));
383         // _slice_cell_tensor
384         input_modulation_gate_input = TensorInfo(TensorShape(output_size, batch_size), 1, DataType::QSYMM16, qsymm_3);
385         ARM_COMPUTE_RETURN_ON_ERROR(NESlice::validate(&output_lowp, &input_modulation_gate_input, { 2 * output_size, 0 }, { 3 * output_size, batch_size }));
386         // _slice_output_tensor
387         output_gate_input = TensorInfo(TensorShape(output_size, batch_size), 1, DataType::QSYMM16, qsymm_3);
388         ARM_COMPUTE_RETURN_ON_ERROR(NESlice::validate(&output_lowp, &output_gate_input, { 3 * output_size, 0 }, { 4 * output_size, batch_size }));
389     }
390     else
391     {
392         // _slice_input_tensor
393         input_gate_input = TensorInfo(TensorShape(output_size), 1, DataType::QSYMM16, qsymm_3);
394         ARM_COMPUTE_RETURN_ON_ERROR(NESlice::validate(&output_lowp, &input_gate_input, { 0 }, { output_size }));
395         // _slice_forget_tensor
396         forget_gate_input = TensorInfo(TensorShape(output_size), 1, DataType::QSYMM16, qsymm_3);
397         ARM_COMPUTE_RETURN_ON_ERROR(NESlice::validate(&output_lowp, &forget_gate_input, { output_size }, { 2 * output_size }));
398         // _slice_cell_tensor
399         input_modulation_gate_input = TensorInfo(TensorShape(output_size), 1, DataType::QSYMM16, qsymm_3);
400         ARM_COMPUTE_RETURN_ON_ERROR(NESlice::validate(&output_lowp, &input_modulation_gate_input, { 2 * output_size }, { 3 * output_size }));
401         // _slice_output_tensor
402         output_gate_input = TensorInfo(TensorShape(output_size), 1, DataType::QSYMM16, qsymm_3);
403         ARM_COMPUTE_RETURN_ON_ERROR(NESlice::validate(&output_lowp, &output_gate_input, { 3 * output_size }, { 4 * output_size }));
404     }
405 
406     // _sigmoid_forget_gate
407     const TensorInfo forget_gate_output(forget_gate_input.tensor_shape(), 1, DataType::QSYMM16, qsymm_0);
408     ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&forget_gate_input, &forget_gate_output, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
409     // _sigmoid_input_gate
410     const TensorInfo input_gate_output(input_gate_input.tensor_shape(), 1, DataType::QSYMM16, qsymm_0);
411     ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&input_gate_input, &input_gate_output, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
412     // _tanh_modulation_gate
413     const TensorInfo input_modulation_gate_output(input_modulation_gate_input.tensor_shape(), 1, DataType::QSYMM16, qsymm_0);
414     ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&input_modulation_gate_input, &input_modulation_gate_output, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.0f, 1.0f)));
415     // _sigmoid_output_gate
416     const TensorInfo output_gate_output(output_gate_input.tensor_shape(), 1, DataType::QSYMM16, qsymm_0);
417     ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&output_gate_input, &output_gate_output, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
418 
419     // _mul_forget_gate_cell_state
420     const TensorInfo cell_state_tmp1(forget_gate_output.tensor_shape(), 1, DataType::QSYMM16, qsymm_4);
421     ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(&forget_gate_output, cell_state_in, &cell_state_tmp1, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
422 
423     // _mul_input_gate_input_mod_gate
424     const TensorInfo cell_state_tmp2(input_gate_output.tensor_shape(), 1, DataType::QSYMM16, qsymm_4);
425     ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(&input_gate_output, &input_modulation_gate_output, &cell_state_tmp2, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
426 
427     // _add_cell_state_tmps
428     ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&cell_state_tmp1, &cell_state_tmp2, cell_state_out, ConvertPolicy::SATURATE));
429 
430     // _tanh_modulation_gate
431     const TensorInfo output_state_tmp(cell_state_out->tensor_shape(), 1, DataType::QSYMM16, qsymm_0);
432     ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(cell_state_out, &output_state_tmp, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.0f, 1.0f)));
433 
434     // _mul_output_state_tmp_output_gate
435     const TensorInfo output_state_out_symm(output_gate_output.tensor_shape(), 1, DataType::QSYMM16, qsymm_0);
436     ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplication::validate(&output_state_tmp, &output_gate_output, &output_state_out_symm, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO));
437 
438     // _dequantize
439     const TensorInfo output_state_out_f32(output_state_out_symm.tensor_shape(), 1, DataType::F32);
440     ARM_COMPUTE_RETURN_ON_ERROR(NEDequantizationLayer::validate(&output_state_out_symm, &output_state_out_f32));
441 
442     // _quantize
443     ARM_COMPUTE_RETURN_ON_ERROR(NEQuantizationLayer::validate(&output_state_out_f32, output_state_out));
444 
445     if(cell_state_out->total_size() != 0)
446     {
447         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&cell_state_info, cell_state_out);
448         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&cell_state_info, cell_state_out);
449         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(&cell_state_info, cell_state_out);
450     }
451 
452     if(output_state_out->total_size() != 0)
453     {
454         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&output_state_info, output_state_out);
455         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&output_state_info, output_state_out);
456         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(&output_state_info, output_state_out);
457     }
458 
459     return Status{};
460 }
461 
run()462 void NELSTMLayerQuantized::run()
463 {
464     prepare();
465 
466     // Acquire all the temporaries
467     MemoryGroupResourceScope scope_mg(_memory_group);
468 
469     // Concat and transpose the input
470     _concat_inputs.run();
471 
472     // Run gemmlowp
473     _gemmlowp.run();
474     _output_stage.run();
475 
476     // Slice the results
477     _slice_input_tensor.run();
478     _slice_forget_tensor.run();
479     _slice_cell_tensor.run();
480     _slice_output_tensor.run();
481 
482     // Gates
483     // Forget gate
484     _sigmoid_forget_gate.run();
485 
486     // Input gate
487     _sigmoid_input_gate.run();
488 
489     // Input modulation gate
490     _tanh_modulation_gate.run();
491 
492     // Output gate
493     _sigmoid_output_gate.run();
494 
495     // Cell state (long term memory)
496     _mul1.run();
497     _mul2.run();
498     _add1.run();
499 
500     // Output state (short term memory)
501     _tanh_output_state.run();
502     _mul3.run();
503 
504     // Requantize output state from QSYMM16 to QASYMM8
505     _dequantize.run();
506     _quantize.run();
507 }
508 
prepare()509 void NELSTMLayerQuantized::prepare()
510 {
511     if(!_is_prepared)
512     {
513         _input_weights.allocator()->allocate();
514         _concat_input_weights.run();
515 
516         _input_to_input_weights->mark_as_unused();
517         _input_to_forget_weights->mark_as_unused();
518         _input_to_cell_weights->mark_as_unused();
519         _input_to_output_weights->mark_as_unused();
520 
521         _recurrent_weights.allocator()->allocate();
522         _concat_recurrent_weights.run();
523         _recurrent_to_input_weights->mark_as_unused();
524         _recurrent_to_forget_weights->mark_as_unused();
525         _recurrent_to_cell_weights->mark_as_unused();
526         _recurrent_to_output_weights->mark_as_unused();
527 
528         _weights.allocator()->allocate();
529         _concat_weights.run();
530 
531         _input_weights.mark_as_unused();
532         _input_weights.allocator()->free();
533         _recurrent_weights.mark_as_unused();
534         _recurrent_weights.allocator()->free();
535 
536         _weights_transposed.allocator()->allocate();
537         _transpose_weights.run();
538 
539         _weights.mark_as_unused();
540         _weights.allocator()->free();
541 
542         _bias.allocator()->allocate();
543         _concat_bias.run();
544         _input_gate_bias->mark_as_unused();
545         _forget_gate_bias->mark_as_unused();
546         _cell_bias->mark_as_unused();
547         _output_gate_bias->mark_as_unused();
548 
549         _is_prepared = true;
550     }
551 }
552 
553 } // namespace arm_compute
554