/*
 * Copyright (c) 2020-2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifndef ARM_COMPUTE_NEQLSTMLAYER_H
#define ARM_COMPUTE_NEQLSTMLAYER_H

#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/NEON/functions/NEActivationLayer.h"
#include "arm_compute/runtime/NEON/functions/NEArithmeticAddition.h"
#include "arm_compute/runtime/NEON/functions/NEArithmeticSubtraction.h"
#include "arm_compute/runtime/NEON/functions/NECopy.h"
#include "arm_compute/runtime/NEON/functions/NEDequantizationLayer.h"
#include "arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h"
#include "arm_compute/runtime/NEON/functions/NEGEMMLowpOutputStage.h"
#include "arm_compute/runtime/NEON/functions/NEPixelWiseMultiplication.h"
#include "arm_compute/runtime/NEON/functions/NEQuantizationLayer.h"
#include "arm_compute/runtime/NEON/functions/NETranspose.h"
#include "arm_compute/runtime/common/LSTMParams.h"

#include <memory>

namespace arm_compute
{
// Forward declarations
class ITensor;
class ITensorInfo;
class NEQLSTMLayerNormalizationKernel;
namespace cpu
{
namespace kernels
{
class CpuGemmLowpMatrixAReductionKernel;
} // namespace kernels
} // namespace cpu
/** Basic function to run @ref NEQLSTMLayer
 *
 * This function calls the following kernels:
 *
 * -# @ref NEActivationLayer                                Activation functions (tanh and logistic)
 * -# @ref NEArithmeticAddition                             Elementwise addition
 * -# @ref NEArithmeticSubtraction                          Elementwise subtraction
 * -# @ref NECopy                                           Copy kernel for copying output_state_out to output
 * -# @ref NEGEMMLowpMatrixMultiplyCore                     Quantized matrix multiplication core. Accumulators are 32-bit integers
 * -# @ref NEGEMMLowpOutputStage                            Convert 32-bit integers into QSYMM16
 * -# @ref cpu::kernels::CpuGemmLowpMatrixAReductionKernel  For precomputing effective biases to use
 * -# @ref NEPixelWiseMultiplication                        Elementwise multiplication
 * -# @ref NETranspose                                      Transpose function for reshaping the weights
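 *
 * A minimal usage sketch (illustrative only, not taken from the library's examples;
 * tensor names, shapes and quantization parameters are assumptions):
 * @code
 * // Source tensor, initialized with the shape and data type documented in configure().
 * // The weight, bias, state and output tensors are set up the same way.
 * Tensor input;
 * input.allocator()->init(TensorInfo(TensorShape(input_size, batch_size), 1,
 *                                    DataType::QASYMM8_SIGNED, QuantizationInfo(0.5f, 10)));
 * // ... declare and init the remaining weights, biases, states and outputs ...
 * LSTMParams<ITensor> lstm_params; // see configure() documentation for its contents
 *
 * NEQLSTMLayer qlstm;
 * qlstm.configure(&input,
 *                 &input_to_forget_weights, &input_to_cell_weights, &input_to_output_weights,
 *                 &recurrent_to_forget_weights, &recurrent_to_cell_weights, &recurrent_to_output_weights,
 *                 &forget_gate_bias, &cell_bias, &output_gate_bias,
 *                 &cell_state_in, &output_state_in,
 *                 &cell_state_out, &output_state_out, &output,
 *                 lstm_params);
 *
 * // Allocate backing memory for all tensors, then execute one LSTM time step per call.
 * input.allocator()->allocate();
 * // ... allocate the remaining tensors ...
 * qlstm.run();
 * @endcode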
 */
class NEQLSTMLayer : public IFunction
{
public:
    /** Default constructor */
    NEQLSTMLayer(std::shared_ptr<IMemoryManager> memory_manager = nullptr);
    /** Prevent instances of this class from being copied (As this class contains pointers) */
    NEQLSTMLayer(const NEQLSTMLayer &) = delete;
    /** Prevent instances of this class from being moved (As this class contains pointers) */
    NEQLSTMLayer(NEQLSTMLayer &&) = delete;
    /** Prevent instances of this class from being copied (As this class contains pointers) */
    NEQLSTMLayer &operator=(const NEQLSTMLayer &) = delete;
    /** Prevent instances of this class from being moved (As this class contains pointers) */
    NEQLSTMLayer &operator=(NEQLSTMLayer &&) = delete;
    /** Default destructor */
    ~NEQLSTMLayer();
    /** Initialize function's tensors.
     *
     * Valid data layouts:
     * - All
     *
     * Valid data type configurations:
     * |src0          |src1 - src6  |src7 - src9  |src10  |src11         |dst0   |dst1 - dst2       |
     * |:-------------|:------------|:------------|:------|:-------------|:------|:-----------------|
     * |QASYMM8_SIGNED|QASYMM8      |S32          |QSYMM16|QASYMM8_SIGNED|QSYMM16|QASYMM8_SIGNED    |
     *
     * @param[in]  input                       Source tensor. Input is a 2D tensor with dimensions [input_size, batch_size]. Data type supported: QASYMM8_SIGNED.
     * @param[in]  input_to_forget_weights     2D weights tensor with dimensions [input_size, num_units]. Data type supported: QSYMM8.
     * @param[in]  input_to_cell_weights       2D weights tensor with dimensions [input_size, num_units]. Data type supported: QSYMM8.
     * @param[in]  input_to_output_weights     2D weights tensor with dimensions [input_size, num_units]. Data type supported: QSYMM8.
     * @param[in]  recurrent_to_forget_weights 2D weights tensor with dimensions [output_size, num_units]. Data type supported: QSYMM8.
     * @param[in]  recurrent_to_cell_weights   2D weights tensor with dimensions [output_size, num_units]. Data type supported: QSYMM8.
     * @param[in]  recurrent_to_output_weights 2D weights tensor with dimensions [output_size, num_units]. Data type supported: QSYMM8.
     * @param[in]  forget_gate_bias            1D weights tensor with dimensions [num_units]. Data type supported: S32.
     * @param[in]  cell_bias                   1D weights tensor with dimensions [num_units]. Data type supported: S32.
     * @param[in]  output_gate_bias            1D weights tensor with dimensions [num_units]. Data type supported: S32.
     * @param[in]  cell_state_in               2D tensor with dimensions [num_units, batch_size]. Data type supported: QSYMM16.
     * @param[in]  output_state_in             2D tensor with dimensions [output_size, batch_size]. Data type supported: Same as @p input.
     * @param[out] cell_state_out              Destination tensor. Output is a 2D tensor with dimensions [num_units, batch_size]. Data type supported: QSYMM16.
     * @param[out] output_state_out            Destination tensor. Output is a 2D tensor with dimensions [output_size, batch_size]. Data type supported: Same as @p input.
     * @param[out] output                      Destination tensor. Output is a 2D tensor with dimensions [output_size, batch_size]. Data type supported: Same as @p input.
     * @param[in]  lstm_params                 Weights tensors used in peephole, CIFG and layer normalization optimizations:
     *                                         input_intermediate_scale   Scale of the intermediate result of matmul, i.e. input to layer normalization, at input gate.
     *                                         forget_intermediate_scale  Scale of the intermediate result of matmul, i.e. input to layer normalization, at forget gate.
     *                                         cell_intermediate_scale    Scale of the intermediate result of matmul, i.e. input to layer normalization, at cell gate.
     *                                         output_intermediate_scale  Scale of the intermediate result of matmul, i.e. input to layer normalization, at output gate.
     *                                         hidden_state_zero          The zero point of the hidden state.
     *                                         hidden_state_scale         The scale of the hidden state.
     *                                         input_to_input_weights     (Optional) 2D weights tensor with dimensions [input_size, num_units]. Data type supported: QSYMM8.
     *                                         recurrent_to_input_weights (Optional) 2D weights tensor with dimensions [output_size, num_units]. Data type supported: QSYMM8.
     *                                         cell_to_input_weights      (Optional) 1D weights tensor with dimensions [num_units]. Can be nullptr. Data type supported: QSYMM16.
     *                                         cell_to_forget_weights     (Optional) 1D weights tensor with dimensions [num_units]. Data type supported: QSYMM16.
     *                                         cell_to_output_weights     (Optional) 1D weights tensor with dimensions [num_units]. Data type supported: QSYMM16.
     *                                         input_gate_bias            (Optional) 1D weights tensor with dimensions [num_units]. Data type supported: S32.
     *                                         projection_weights         (Optional) 2D weights tensor with dimensions [output_size, num_units]. Data type supported: QSYMM8.
     *                                         projection_bias            (Optional) 1D weights tensor with dimensions [output_size]. Data type supported: S32.
     *                                         input_layer_norm_weights   (Optional) 1D weights tensor with dimensions [num_units]. Data type supported: QSYMM16.
     *                                         forget_layer_norm_weights  (Optional) 1D weights tensor with dimensions [num_units]. Data type supported: QSYMM16.
     *                                         cell_layer_norm_weights    (Optional) 1D weights tensor with dimensions [num_units]. Data type supported: QSYMM16.
     *                                         output_layer_norm_weights  (Optional) 1D weights tensor with dimensions [num_units]. Data type supported: QSYMM16.
     *                                         cell_threshold             (Optional) The clipping threshold for the cell state, such that values are bound within [-cell_clip, cell_clip].
     *                                                                               If set to 0.0 then clipping is disabled.
     *                                         projection_threshold       (Optional) The clipping threshold for the output from the projection layer, such that values are bound within
     *                                                                               [-proj_clip, proj_clip]. If set to 0.0 then clipping is disabled.
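     *
     * A sketch of populating @p lstm_params (assuming the setters declared in
     * arm_compute/runtime/common/LSTMParams.h; the scale and zero-point values are placeholders):
     * @code
     * LSTMParams<ITensor> lstm_params;
     * lstm_params.set_matmul_scale_params(input_intermediate_scale, forget_intermediate_scale,
     *                                     cell_intermediate_scale, output_intermediate_scale);
     * lstm_params.set_hidden_state_params(hidden_state_zero, hidden_state_scale);
     * // Optional features (CIFG, peephole, projection, layer normalization) are enabled
     * // through the corresponding LSTMParams setters with the tensors listed above.
     * @endcode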
     */
    void configure(const ITensor *input,
                   const ITensor *input_to_forget_weights, const ITensor *input_to_cell_weights, const ITensor *input_to_output_weights,
                   const ITensor *recurrent_to_forget_weights, const ITensor *recurrent_to_cell_weights, const ITensor *recurrent_to_output_weights,
                   const ITensor *forget_gate_bias, const ITensor *cell_bias, const ITensor *output_gate_bias,
                   const ITensor *cell_state_in, ITensor *output_state_in,
                   ITensor *cell_state_out, ITensor *output_state_out, ITensor *output,
                   const LSTMParams<ITensor> &lstm_params);

    /** Static function to check if given info will lead to a valid configuration of @ref NEQLSTMLayer
     *
     * @param[in] input                       Source tensor info. Input is a 2D tensor info with dimensions [input_size, batch_size]. Data type supported: QASYMM8_SIGNED.
     * @param[in] input_to_forget_weights     2D weights tensor info with dimensions [input_size, num_units]. Data type supported: QSYMM8.
     * @param[in] input_to_cell_weights       2D weights tensor info with dimensions [input_size, num_units]. Data type supported: QSYMM8.
     * @param[in] input_to_output_weights     2D weights tensor info with dimensions [input_size, num_units]. Data type supported: QSYMM8.
     * @param[in] recurrent_to_forget_weights 2D weights tensor info with dimensions [output_size, num_units]. Data type supported: QSYMM8.
     * @param[in] recurrent_to_cell_weights   2D weights tensor info with dimensions [output_size, num_units]. Data type supported: QSYMM8.
     * @param[in] recurrent_to_output_weights 2D weights tensor info with dimensions [output_size, num_units]. Data type supported: QSYMM8.
     * @param[in] forget_gate_bias            1D weights tensor info with dimensions [num_units]. Data type supported: S32.
     * @param[in] cell_bias                   1D weights tensor info with dimensions [num_units]. Data type supported: S32.
     * @param[in] output_gate_bias            1D weights tensor info with dimensions [num_units]. Data type supported: S32.
     * @param[in] cell_state_in               2D tensor info with dimensions [num_units, batch_size]. Data type supported: QSYMM16.
     * @param[in] output_state_in             2D tensor info with dimensions [output_size, batch_size]. Data type supported: Same as @p input.
     * @param[in] cell_state_out              Destination tensor info. Output is a 2D tensor info with dimensions [num_units, batch_size]. Data type supported: QSYMM16.
     * @param[in] output_state_out            Destination tensor info. Output is a 2D tensor info with dimensions [output_size, batch_size]. Data type supported: Same as @p input.
     * @param[in] output                      Destination tensor info. Output is a 2D tensor info with dimensions [output_size, batch_size]. Data type supported: Same as @p input.
     * @param[in] lstm_params                 Weights tensors info used in peephole, CIFG and layer normalization optimizations:
     *                                        input_intermediate_scale   Scale of the intermediate result of matmul, i.e. input to layer normalization, at input gate.
     *                                        forget_intermediate_scale  Scale of the intermediate result of matmul, i.e. input to layer normalization, at forget gate.
     *                                        cell_intermediate_scale    Scale of the intermediate result of matmul, i.e. input to layer normalization, at cell gate.
     *                                        output_intermediate_scale  Scale of the intermediate result of matmul, i.e. input to layer normalization, at output gate.
     *                                        hidden_state_zero          The zero point of the hidden state.
     *                                        hidden_state_scale         The scale of the hidden state.
     *                                        input_to_input_weights     (Optional) 2D weights tensor info with dimensions [input_size, num_units]. Data type supported: QSYMM8.
     *                                        recurrent_to_input_weights (Optional) 2D weights tensor info with dimensions [output_size, num_units]. Data type supported: QSYMM8.
     *                                        cell_to_input_weights      (Optional) 1D weights tensor info with dimensions [num_units]. Can be nullptr. Data type supported: QSYMM16.
     *                                        cell_to_forget_weights     (Optional) 1D weights tensor info with dimensions [num_units]. Data type supported: QSYMM16.
     *                                        cell_to_output_weights     (Optional) 1D weights tensor info with dimensions [num_units]. Data type supported: QSYMM16.
     *                                        input_gate_bias            (Optional) 1D weights tensor info with dimensions [num_units]. Data type supported: S32.
     *                                        projection_weights         (Optional) 2D weights tensor info with dimensions [output_size, num_units]. Data type supported: QSYMM8.
     *                                        projection_bias            (Optional) 1D weights tensor info with dimensions [output_size]. Data type supported: S32.
     *                                        input_layer_norm_weights   (Optional) 1D weights tensor info with dimensions [num_units]. Data type supported: QSYMM16.
     *                                        forget_layer_norm_weights  (Optional) 1D weights tensor info with dimensions [num_units]. Data type supported: QSYMM16.
     *                                        cell_layer_norm_weights    (Optional) 1D weights tensor info with dimensions [num_units]. Data type supported: QSYMM16.
     *                                        output_layer_norm_weights  (Optional) 1D weights tensor info with dimensions [num_units]. Data type supported: QSYMM16.
     *                                        cell_threshold             (Optional) The clipping threshold for the cell state, such that values are bound within [-cell_clip, cell_clip].
     *                                                                              If set to 0.0 then clipping is disabled.
     *                                        projection_threshold       (Optional) The clipping threshold for the output from the projection layer, such that values are bound within
     *                                                                              [-proj_clip, proj_clip]. If set to 0.0 then clipping is disabled.
     * @return a status
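     *
     * A typical call pattern (illustrative only) is to validate before configuring, passing the
     * ITensorInfo of each tensor:
     * @code
     * // Pseudo-code: "..." stands for the remaining tensor infos / tensors listed above.
     * const Status status = NEQLSTMLayer::validate(input.info(), ..., lstm_params_info);
     * if(status.error_code() == ErrorCode::OK)
     * {
     *     qlstm.configure(&input, ..., lstm_params);
     * }
     * @endcode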
     */
    static Status validate(const ITensorInfo *input,
                           const ITensorInfo *input_to_forget_weights, const ITensorInfo *input_to_cell_weights, const ITensorInfo *input_to_output_weights,
                           const ITensorInfo *recurrent_to_forget_weights, const ITensorInfo *recurrent_to_cell_weights, const ITensorInfo *recurrent_to_output_weights,
                           const ITensorInfo *forget_gate_bias, const ITensorInfo *cell_bias, const ITensorInfo *output_gate_bias,
                           const ITensorInfo *cell_state_in, const ITensorInfo *output_state_in,
                           const ITensorInfo *cell_state_out, const ITensorInfo *output_state_out, const ITensorInfo *output,
                           const LSTMParams<ITensorInfo> &lstm_params);

    // Inherited methods overridden:
    void run() override;
    void prepare() override;

private:
    enum class LayerNormGate : uint8_t
    {
        Forget,
        Cell,
        Input,
        Output,
        Count
    };
    static constexpr uint8_t  _layer_norm_count                    = static_cast<uint8_t>(LayerNormGate::Count);
    static constexpr uint32_t _out_state_output_size_dimension_idx = 0;

    /** Internal method to configure matrix multiplication plus output stage of each gate.
     *
     * @param[in] mm                    Matrix multiplication function to use.
     * @param[in] outstage              Output stage function to use.
     * @param[in] gemmlowp_info         GEMMLowp metadata to be used by the output stage.
     * @param[in] mm_input              Input tensor to matrix multiplication function.
     * @param[in] mm_weights            Weights tensor to matrix multiplication function.
     * @param[in] bias                  Bias tensor to matrix multiplication function.
     * @param[in] mm_res                Tensor to be used for storing the result of the matrix multiplication.
     * @param[in] outstage_res          Tensor to be used for storing the result of the output stage.
     * @param[in] gemmlowp_scale        Real multiplier to be used when computing the multiplier and shift for requantization.
     * @param[in] mm_res_info           Tensor info to be used to initialize the matrix multiplication result tensor.
     * @param[in] outstage_tensor_info  Tensor info to be used to initialize the output stage result tensor.
     *
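     * @note @p gemmlowp_scale is a real-valued multiplier; the output stage operates on the
     *       fixed-point integer multiplier and shift derived from it, following the usual
     *       GEMMLowp requantization scheme.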
     */
    void configure_mm(NEGEMMLowpMatrixMultiplyCore &mm, NEGEMMLowpOutputStage &outstage, GEMMLowpOutputStageInfo &gemmlowp_info,
                      const ITensor *mm_input, const ITensor *mm_weights, const ITensor *bias, Tensor *mm_res,
                      Tensor *outstage_res, float gemmlowp_scale,
                      const TensorInfo &mm_res_info, const TensorInfo &outstage_tensor_info);

    MemoryGroup _memory_group;

    /** A small internal kernel to do the copy between two tensors */
    class TensorCopyKernel
    {
        static constexpr uint32_t max_dimension_supported = 2;

        ITensor *_src{ nullptr };
        ITensor *_dst{ nullptr };
        size_t   _row_size{};
        Window   _window{};

    public:
        /** Destructor */
        ~TensorCopyKernel();
        /** Static function to check if given info will lead to a valid configuration of @ref NEQLSTMLayer::TensorCopyKernel
         *
         * @param[in] src Source tensor info.
         * @param[in] dst Destination tensor info.
         *
         * @return a status
         */
        static Status validate(const ITensorInfo &src, const ITensorInfo &dst);
        /** Set the input and output tensors.
         *
         * @param[in]  src Source tensor
         * @param[out] dst Destination tensor
         */
        void configure(ITensor &src, ITensor &dst);
        /** Run the kernel */
        void run();
    };

    // Functions used

    NEDequantizationLayer                                            _dequantize_input_to_forget_weights;
    NEQuantizationLayer                                              _quantize_input_to_forget_weights;
    NETranspose                                                      _transpose_input_to_forget_weights;
    NETranspose                                                      _transpose_input_to_cell_weights;
    NETranspose                                                      _transpose_input_to_output_weights;
    NETranspose                                                      _transpose_input_to_input_weights;
    NETranspose                                                      _transpose_recurrent_to_forget_weights;
    NETranspose                                                      _transpose_recurrent_to_cell_weights;
    NETranspose                                                      _transpose_recurrent_to_output_weights;
    NETranspose                                                      _transpose_recurrent_to_input_weights;
    NETranspose                                                      _transpose_projection_weights;
    std::unique_ptr<cpu::kernels::CpuGemmLowpMatrixAReductionKernel> _input_to_input_reduction;
    std::unique_ptr<cpu::kernels::CpuGemmLowpMatrixAReductionKernel> _recurrent_to_input_reduction;
    std::unique_ptr<cpu::kernels::CpuGemmLowpMatrixAReductionKernel> _input_to_forget_reduction;
    std::unique_ptr<cpu::kernels::CpuGemmLowpMatrixAReductionKernel> _recurrent_to_forget_reduction;
    std::unique_ptr<cpu::kernels::CpuGemmLowpMatrixAReductionKernel> _input_to_cell_reduction;
    std::unique_ptr<cpu::kernels::CpuGemmLowpMatrixAReductionKernel> _recurrent_to_cell_reduction;
    std::unique_ptr<cpu::kernels::CpuGemmLowpMatrixAReductionKernel> _input_to_output_reduction;
    std::unique_ptr<cpu::kernels::CpuGemmLowpMatrixAReductionKernel> _recurrent_to_output_reduction;
    std::unique_ptr<cpu::kernels::CpuGemmLowpMatrixAReductionKernel> _projection_reduction;
    NEArithmeticAddition                                             _projection_bias_add;
    NEGEMMLowpMatrixMultiplyCore                                     _mm_input_to_forget;
    NEGEMMLowpMatrixMultiplyCore                                     _mm_recurrent_to_forget;
    NEPixelWiseMultiplication                                        _pixelwise_mul_cell_to_forget;
    NEGEMMLowpOutputStage                                            _input_to_forget_outstage;
    NEGEMMLowpOutputStage                                            _recurrent_to_forget_outstage;
    NEGEMMLowpOutputStage                                            _cell_to_forget_outstage;
    NEArithmeticAddition                                             _accumulate_input_recurrent_forget;
    NEArithmeticAddition                                             _accumulate_cell_forget;
    NEActivationLayer                                                _forget_gate_sigmoid;
    NEGEMMLowpMatrixMultiplyCore                                     _mm_input_to_cell;
    NEGEMMLowpOutputStage                                            _input_to_cell_outstage;
    NEGEMMLowpMatrixMultiplyCore                                     _mm_recurrent_to_cell;
    NEGEMMLowpOutputStage                                            _recurrent_to_cell_outstage;
    NEArithmeticAddition                                             _accumulate_input_recurrent_modulation;
    NEActivationLayer                                                _cell_gate_tanh;
    NEArithmeticSubtraction                                          _input_gate_sub;
    NEGEMMLowpMatrixMultiplyCore                                     _mm_input_to_input;
    NEGEMMLowpOutputStage                                            _input_to_input_outstage;
    NEGEMMLowpMatrixMultiplyCore                                     _mm_recurrent_to_input;
    NEGEMMLowpOutputStage                                            _recurrent_to_input_outstage;
    NEArithmeticAddition                                             _accumulate_input_recurrent_input;
    NEPixelWiseMultiplication                                        _pixelwise_mul_cell_to_input;
    NEGEMMLowpOutputStage                                            _cell_to_input_outstage;
    NEArithmeticAddition                                             _accumulate_cell_input;
    NEActivationLayer                                                _input_gate_sigmoid;
    NEPixelWiseMultiplication                                        _pixelwise_mul_forget_cell;
    NEPixelWiseMultiplication                                        _pixelwise_mul_input_cell;
    NEArithmeticAddition                                             _add_forget_cell;
    NEActivationLayer                                                _cell_clip;
    NEGEMMLowpMatrixMultiplyCore                                     _mm_input_to_output;
    NEGEMMLowpOutputStage                                            _input_to_output_outstage;
    NEGEMMLowpMatrixMultiplyCore                                     _mm_recurrent_to_output;
    NEGEMMLowpOutputStage                                            _recurrent_to_output_outstage;
    NEArithmeticAddition                                             _accumulate_input_recurrent_output;
    NEPixelWiseMultiplication                                        _pixelwise_mul_cell_to_output;
    NEGEMMLowpOutputStage                                            _cell_to_output_outstage;
    NEArithmeticAddition                                             _accumulate_cell_to_output;
    NEActivationLayer                                                _output_gate_sigmoid;
    NEActivationLayer                                                _hidden_tanh;
    NEPixelWiseMultiplication                                        _pixelwise_mul_hidden;
    NEGEMMLowpOutputStage                                            _hidden_outstage;
    NEGEMMLowpMatrixMultiplyCore                                     _mm_projection;
    NEGEMMLowpOutputStage                                            _projection_outstage;
    NEArithmeticAddition                                             _accumulate_projection;
    NEActivationLayer                                                _projection_clip;

    TensorCopyKernel _projection_bias_copy;
    TensorCopyKernel _projection_output_to_accumulate_copy;
    TensorCopyKernel _projection_accumulate_to_output_copy;
    TensorCopyKernel _hidden_to_output_copy;

    std::array<std::unique_ptr<NEQLSTMLayerNormalizationKernel>, _layer_norm_count> _layer_norms;

    NECopy _copy_output;

    // Tensor pointers
    const ITensor *_input_to_input_weights{ nullptr };
    const ITensor *_recurrent_to_input_weights{ nullptr };
    const ITensor *_projection_bias{ nullptr };
    const ITensor *_input_to_forget_weights{ nullptr };
    const ITensor *_input_to_cell_weights{ nullptr };
    const ITensor *_input_to_output_weights{ nullptr };
    const ITensor *_recurrent_to_forget_weights{ nullptr };
    const ITensor *_recurrent_to_cell_weights{ nullptr };
    const ITensor *_recurrent_to_output_weights{ nullptr };
    const ITensor *_projection_weights{ nullptr };
    std::array<const ITensor *, _layer_norm_count> _layer_norm_weights{};
    std::array<const ITensor *, _layer_norm_count> _layer_norm_bias{};

    using LayerNormIndexType = typename std::underlying_type<LayerNormGate>::type;
    inline LayerNormIndexType getGateIndex(LayerNormGate g)
    {
        return static_cast<LayerNormIndexType>(g);
    }

    inline void set_layer_norm_weight(const ITensor *t, LayerNormGate g)
    {
        _layer_norm_weights[getGateIndex(g)] = t;
    }

    inline void set_layer_norm_bias(const ITensor *t, LayerNormGate g)
    {
        _layer_norm_bias[getGateIndex(g)] = t;
    }

    inline const ITensor *get_layer_norm_weight(LayerNormGate g)
    {
        return _layer_norm_weights[getGateIndex(g)];
    }

    inline const ITensor *get_layer_norm_bias(LayerNormGate g)
    {
        return _layer_norm_bias[getGateIndex(g)];
    }

    inline std::unique_ptr<NEQLSTMLayerNormalizationKernel> &get_layer_norm(LayerNormGate g)
    {
        return _layer_norms[getGateIndex(g)];
    }
    void configure_layer_norm(LayerNormGate g, const ITensor *in);
    static Status validate_layer_norm(const ITensorInfo &in, const ITensorInfo &weight, const ITensorInfo &bias);

    // Temporary tensors
    Tensor _input_to_forget_weights_f32{ nullptr };
    Tensor _input_to_forget_weights_symm8{ nullptr };

    Tensor _input_to_forget_weights_transposed{ nullptr };
    Tensor _input_to_cell_weights_transposed{ nullptr };
    Tensor _input_to_output_weights_transposed{ nullptr };
    Tensor _input_to_input_weights_transposed{ nullptr };
    Tensor _recurrent_to_forget_weights_transposed{ nullptr };
    Tensor _recurrent_to_cell_weights_transposed{ nullptr };
    Tensor _recurrent_to_output_weights_transposed{ nullptr };
    Tensor _recurrent_to_input_weights_transposed{ nullptr };
    Tensor _projection_weights_transposed{ nullptr };
    Tensor _input_to_input_eff_bias{ nullptr };
    Tensor _recurrent_to_input_eff_bias{ nullptr };
    Tensor _input_to_forget_eff_bias{ nullptr };
    Tensor _recurrent_to_forget_eff_bias{ nullptr };
    Tensor _input_to_cell_eff_bias{ nullptr };
    Tensor _recurrent_to_cell_eff_bias{ nullptr };
    Tensor _input_to_output_eff_bias{ nullptr };
    Tensor _recurrent_to_output_eff_bias{ nullptr };
    Tensor _projection_reduction_res{ nullptr };
    Tensor _projection_eff_bias{ nullptr };
    Tensor _mm_input_to_forget_res{ nullptr };
    Tensor _mm_recurrent_to_forget_res{ nullptr };
    Tensor _mul_cell_to_forget_res{ nullptr };
    Tensor _input_to_forget_outstage_res{ nullptr };
    Tensor _cell_to_forget_outstage_res{ nullptr };
    Tensor _recurrent_to_forget_outstage_res{ nullptr };
    Tensor _forget_gate{ nullptr };
    Tensor _mm_input_to_cell_res{ nullptr };
    Tensor _input_to_cell_outstage_res{ nullptr };
    Tensor _mm_recurrent_to_cell_res{ nullptr };
    Tensor _recurrent_to_cell_outstage_res{ nullptr };
    Tensor _cell_gate{ nullptr };
    Tensor _mul_input_cell_res{ nullptr };
    Tensor _mm_input_to_input_res{ nullptr };
    Tensor _input_to_input_outstage_res{ nullptr };
    Tensor _mm_recurrent_to_input_res{ nullptr };
    Tensor _mul_cell_to_input_res{ nullptr };
    Tensor _cell_to_input_outstage_res{ nullptr };
    Tensor _recurrent_to_input_outstage_res{ nullptr };
    Tensor _input_gate{ nullptr };
    Tensor _mm_input_to_output_res{ nullptr };
    Tensor _input_to_output_outstage_res{ nullptr };
    Tensor _mm_recurrent_to_output_res{ nullptr };
    Tensor _mul_cell_to_output_res{ nullptr };
    Tensor _cell_to_output_outstage_res{ nullptr };
    Tensor _recurrent_to_output_outstage_res{ nullptr };
    Tensor _output_gate{ nullptr };
    Tensor _hidden_mul_res{ nullptr };
    Tensor _hidden_gate{ nullptr };
    Tensor _mm_projection_res{ nullptr };
    Tensor _projection_outstage_res{ nullptr };
    Tensor _projection_out_res{ nullptr };
    Tensor _projection_accumulate_res{ nullptr };
    Tensor _ones{ nullptr };
    std::array<Tensor, _layer_norm_count> _layer_norm_output{};

    inline Tensor &get_layer_norm_output(LayerNormGate g)
    {
        return _layer_norm_output[getGateIndex(g)];
    }

    bool _is_prepared{ false };
    bool _has_cifg{ false };
    bool _has_cell_clipping{ false };
    bool _has_projection{ false };
    bool _has_projection_clipping{ false };
    bool _has_peephole{ false };
    bool _has_layer_norm{ false };
    bool _projection_tensor_copy_required{ false };
    bool _convert_input_to_forget_weights_to_qsymm8{ false };
};
} // namespace arm_compute
#endif /* ARM_COMPUTE_NEQLSTMLAYER_H */