xref: /aosp_15_r20/external/ComputeLibrary/arm_compute/runtime/CL/functions/CLQLSTMLayer.h (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1 /*
2  * Copyright (c) 2020-2021 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #ifndef ARM_COMPUTE_CLQLSTMLAYER_H
25 #define ARM_COMPUTE_CLQLSTMLAYER_H
26 
27 #include "arm_compute/core/Types.h"
28 #include "arm_compute/runtime/CL/functions/CLActivationLayer.h"
29 #include "arm_compute/runtime/CL/functions/CLCopy.h"
30 #include "arm_compute/runtime/CL/functions/CLElementwiseOperations.h"
31 #include "arm_compute/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.h"
32 #include "arm_compute/runtime/CL/functions/CLGEMMLowpOutputStage.h"
33 #include "arm_compute/runtime/CL/functions/CLPixelWiseMultiplication.h"
34 #include "arm_compute/runtime/CL/functions/CLTranspose.h"
35 
36 #include "arm_compute/runtime/common/LSTMParams.h"
37 
38 namespace arm_compute
39 {
40 // Forward declarations
41 class CLCompileContext;
42 class ICLTensor;
43 class CLQLSTMLayerNormalizationKernel;
44 class ITensorInfo;
45 namespace opencl
46 {
47 namespace kernels
48 {
49 class ClGemmLowpMatrixAReductionKernel;
50 } // namespace kernels
51 } // namespace opencl
52 
53 /** Basic function to run @ref CLQLSTMLayer
54  *
55  * This function calls the following CL functions/kernels:
56  *
57  * -# @ref CLActivationLayer                                     Activation functions (tanh and logistic)
58  * -# @ref CLCopy                                                Copy function for copying output_state_out to output
59  * -# @ref CLArithmeticAddition / @ref CLArithmeticSubtraction   Elementwise addition and subtraction
60  * -# @ref CLGEMMLowpMatrixMultiplyCore                          Quantized matrix multiplication core. Accumulators are 32-bit integers
61  * -# @ref CLGEMMLowpOutputStage   Convert 32-bit integers into QSYMM16
62  * -# @ref opencl::kernels::ClGemmLowpMatrixAReductionKernel                      For precomputing effective biases to use
63  * -# @ref CLPixelWiseMultiplication                             Elementwise multiplication
64  * -# @ref CLTranspose                                           Transpose function for reshaping the weights
65  * */
66 class CLQLSTMLayer : public IFunction
67 {
68 public:
69     /** Constructor; @p memory_manager is optional and defaults to nullptr (i.e. no memory manager supplied) */
70     CLQLSTMLayer(std::shared_ptr<IMemoryManager> memory_manager = nullptr);
71     /** Copy construction is disabled (this class contains pointers) */
72     CLQLSTMLayer(const CLQLSTMLayer &) = delete;
73     /** Move constructor (defaulted) */
74     CLQLSTMLayer(CLQLSTMLayer &&) = default;
75     /** Copy assignment is disabled (this class contains pointers) */
76     CLQLSTMLayer &operator=(const CLQLSTMLayer &) = delete;
77     /** Move assignment operator (defaulted) */
78     CLQLSTMLayer &operator=(CLQLSTMLayer &&) = default;
79     /** Destructor; only declared here, so its definition is provided outside this class declaration */
80     ~CLQLSTMLayer();
81     /** Initialize function's tensors.
82      *
83      * Valid data layouts:
84      * - All
85      *
86      * Valid data type configurations (NOTE(review): the table gives QASYMM8 for src1 - src6, but the parameter docs below say QSYMM8 - confirm which is right):
87      * |src0          |src1 - src6  |src7 - src9  |src10  |src11         |dst0   |dst1 - dst2       |
88      * |:-------------|:------------|:------------|:------|:-------------|:------|:-----------------|
89      * |QASYMM8_SIGNED|QASYMM8      |S32          |QSYMM16|QASYMM8_SIGNED|QSYMM16|QASYMM8_SIGNED    |
90      *
91      * @param[in]  input                       Source tensor. Input is a 2D tensor with dimensions [input_size, batch_size]. Data types supported: QASYMM8_SIGNED.
92      * @param[in]  input_to_forget_weights     2D weights tensor with dimensions [input_size, num_units]. Data type supported: QSYMM8.
93      * @param[in]  input_to_cell_weights       2D weights tensor with dimensions [input_size, num_units]. Data type supported: QSYMM8.
94      * @param[in]  input_to_output_weights     2D weights tensor with dimensions [input_size, num_units]. Data type supported: QSYMM8.
95      * @param[in]  recurrent_to_forget_weights 2D weights tensor with dimensions [output_size, num_units]. Data type supported: QSYMM8.
96      * @param[in]  recurrent_to_cell_weights   2D weights tensor with dimensions [output_size, num_units]. Data type supported: QSYMM8.
97      * @param[in]  recurrent_to_output_weights 2D weights tensor with dimensions [output_size, num_units]. Data type supported: QSYMM8.
98      * @param[in]  forget_gate_bias            1D weights tensor with dimensions [num_units]. Data type supported: S32.
99      * @param[in]  cell_bias                   1D weights tensor with dimensions [num_units]. Data type supported: S32.
100      * @param[in]  output_gate_bias            1D weights tensor with dimensions [num_units]. Data type supported: S32.
101      * @param[in]  cell_state_in               2D tensor with dimensions [num_units, batch_size]. Data type supported: QSYMM16.
102      * @param[in]  output_state_in             2D tensor with dimensions [output_size, batch_size]. Data type supported: Same as @p input.
103      * @param[out] cell_state_out              Destination tensor. Output is a 2D tensor with dimensions [num_units, batch_size]. Data type supported: QSYMM16.
104      * @param[out] output_state_out            Destination tensor. Output is a 2D tensor with dimensions [output_size, batch_size]. Data types supported: Same as @p input.
105      * @param[out] output                      Destination tensor. Output is a 2D tensor with dimensions [output_size, batch_size]. Data types supported: Same as @p input.
106      * @param[in]  lstm_params                 Weights tensors used in peephole, CIFG and layer normalization optimizations:
107      *                                         input_intermediate_scale   Scale of the intermediate result of matmul, i.e. input to layer normalization, at input gate.
108      *                                         forget_intermediate_scale  Scale of the intermediate result of matmul, i.e. input to layer normalization, at forget gate.
109      *                                         cell_intermediate_scale    Scale of the intermediate result of matmul, i.e. input to layer normalization, at cell gate.
110      *                                         output_intermediate_scale  Scale of the intermediate result of matmul, i.e. input to layer normalization, at output gate.
111      *                                         hidden_state_zero          The zero point of the hidden state.
112      *                                         hidden_state_scale         The scale of the hidden state.
113      *                                         input_to_input_weights     (Optional) 2D weights tensor with dimensions [input_size, num_units]. Data type supported: QSYMM8.
114      *                                         recurrent_to_input_weights (Optional) 2D weights tensor with dimensions [output_size, num_units]. Data type supported: QSYMM8.
115      *                                         cell_to_input_weights      (Optional) 1D weights tensor with dimensions [num_units]. Can be nullptr. Data type supported: QSYMM16.
116      *                                         cell_to_forget_weights     (Optional) 1D weights tensor with dimensions [num_units]. Data type supported: QSYMM16.
117      *                                         cell_to_output_weights     (Optional) 1D weights tensor with dimensions [num_units]. Data type supported: QSYMM16.
118      *                                         input_gate_bias            (Optional) 1D weights tensor with dimensions [num_units]. Data type supported: S32.
119      *                                         projection_weights         (Optional) 2D weights tensor with dimensions [output_size, num_units]. Data type supported: QSYMM8.
120      *                                         projection_bias            (Optional) 1D weights tensor with dimensions [output_size]. Data type supported: S32.
121      *                                         input_layer_norm_weights   (Optional) 1D weights tensor with dimensions [num_units]. Data type supported: QSYMM16.
122      *                                         forget_layer_norm_weights  (Optional) 1D weights tensor with dimensions [num_units]. Data type supported: QSYMM16.
123      *                                         cell_layer_norm_weights    (Optional) 1D weights tensor with dimensions [num_units]. Data type supported: QSYMM16.
124      *                                         output_layer_norm_weights  (Optional) 1D weights tensor with dimensions [num_units]. Data type supported: QSYMM16.
125      *                                         cell_threshold             (Optional) The clipping threshold for the cell state, such that values are bound within [-cell_clip, cell_clip].
126      *                                                                               If set to 0.0 then clipping is disabled.
127      *                                         projection_threshold       (Optional) The clipping threshold for the output from the projection layer, such that values are bound within
128      *                                                                               [-proj_clip, proj_clip]. If set to 0.0 then clipping is disabled.
129      */
130     void configure(const ICLTensor *input,
131                    const ICLTensor *input_to_forget_weights, const ICLTensor *input_to_cell_weights, const ICLTensor *input_to_output_weights,
132                    const ICLTensor *recurrent_to_forget_weights, const ICLTensor *recurrent_to_cell_weights, const ICLTensor *recurrent_to_output_weights,
133                    const ICLTensor *forget_gate_bias, const ICLTensor *cell_bias, const ICLTensor *output_gate_bias,
134                    ICLTensor *cell_state_in, ICLTensor *output_state_in,
135                    ICLTensor *cell_state_out, ICLTensor *output_state_out, ICLTensor *output,
136                    const LSTMParams<ICLTensor> &lstm_params);
137 
138     /** Initialize function's tensors, using an explicitly supplied compile context.
139      *
140      * @param[in]  compile_context             The compile context to be used.
141      * @param[in]  input                       Source tensor. Input is a 2D tensor with dimensions [input_size, batch_size]. Data types supported: QASYMM8_SIGNED.
142      * @param[in]  input_to_forget_weights     2D weights tensor with dimensions [input_size, num_units]. Data type supported: QSYMM8.
143      * @param[in]  input_to_cell_weights       2D weights tensor with dimensions [input_size, num_units]. Data type supported: QSYMM8.
144      * @param[in]  input_to_output_weights     2D weights tensor with dimensions [input_size, num_units]. Data type supported: QSYMM8.
145      * @param[in]  recurrent_to_forget_weights 2D weights tensor with dimensions [output_size, num_units]. Data type supported: QSYMM8.
146      * @param[in]  recurrent_to_cell_weights   2D weights tensor with dimensions [output_size, num_units]. Data type supported: QSYMM8.
147      * @param[in]  recurrent_to_output_weights 2D weights tensor with dimensions [output_size, num_units]. Data type supported: QSYMM8.
148      * @param[in]  forget_gate_bias            1D weights tensor with dimensions [num_units]. Data type supported: S32.
149      * @param[in]  cell_bias                   1D weights tensor with dimensions [num_units]. Data type supported: S32.
150      * @param[in]  output_gate_bias            1D weights tensor with dimensions [num_units]. Data type supported: S32.
151      * @param[in]  cell_state_in               2D tensor with dimensions [num_units, batch_size]. Data type supported: QSYMM16.
152      * @param[in]  output_state_in             2D tensor with dimensions [output_size, batch_size]. Data type supported: Same as @p input.
153      * @param[out] cell_state_out              Destination tensor. Output is a 2D tensor with dimensions [num_units, batch_size]. Data type supported: QSYMM16.
154      * @param[out] output_state_out            Destination tensor. Output is a 2D tensor with dimensions [output_size, batch_size]. Data types supported: Same as @p input.
155      * @param[out] output                      Destination tensor. Output is a 2D tensor with dimensions [output_size, batch_size]. Data types supported: Same as @p input.
156      * @param[in]  lstm_params                 Weights tensors used in peephole, CIFG and layer normalization optimizations:
157      *                                         input_intermediate_scale   Scale of the intermediate result of matmul, i.e. input to layer normalization, at input gate.
158      *                                         forget_intermediate_scale  Scale of the intermediate result of matmul, i.e. input to layer normalization, at forget gate.
159      *                                         cell_intermediate_scale    Scale of the intermediate result of matmul, i.e. input to layer normalization, at cell gate.
160      *                                         output_intermediate_scale  Scale of the intermediate result of matmul, i.e. input to layer normalization, at output gate.
161      *                                         hidden_state_zero          The zero point of the hidden state.
162      *                                         hidden_state_scale         The scale of the hidden state.
163      *                                         input_to_input_weights     (Optional) 2D weights tensor with dimensions [input_size, num_units]. Data type supported: QSYMM8.
164      *                                         recurrent_to_input_weights (Optional) 2D weights tensor with dimensions [output_size, num_units]. Data type supported: QSYMM8.
165      *                                         cell_to_input_weights      (Optional) 1D weights tensor with dimensions [num_units]. Can be nullptr. Data type supported: QSYMM16.
166      *                                         cell_to_forget_weights     (Optional) 1D weights tensor with dimensions [num_units]. Data type supported: QSYMM16.
167      *                                         cell_to_output_weights     (Optional) 1D weights tensor with dimensions [num_units]. Data type supported: QSYMM16.
168      *                                         input_gate_bias            (Optional) 1D weights tensor with dimensions [num_units]. Data type supported: S32.
169      *                                         projection_weights         (Optional) 2D weights tensor with dimensions [output_size, num_units]. Data type supported: QSYMM8.
170      *                                         projection_bias            (Optional) 1D weights tensor with dimensions [output_size]. Data type supported: S32.
171      *                                         input_layer_norm_weights   (Optional) 1D weights tensor with dimensions [num_units]. Data type supported: QSYMM16.
172      *                                         forget_layer_norm_weights  (Optional) 1D weights tensor with dimensions [num_units]. Data type supported: QSYMM16.
173      *                                         cell_layer_norm_weights    (Optional) 1D weights tensor with dimensions [num_units]. Data type supported: QSYMM16.
174      *                                         output_layer_norm_weights  (Optional) 1D weights tensor with dimensions [num_units]. Data type supported: QSYMM16.
175      *                                         cell_threshold             (Optional) The clipping threshold for the cell state, such that values are bound within [-cell_clip, cell_clip].
176      *                                                                               If set to 0.0 then clipping is disabled.
177      *                                         projection_threshold       (Optional) The clipping threshold for the output from the projection layer, such that values are bound within
178      *                                                                               [-proj_clip, proj_clip]. If set to 0.0 then clipping is disabled.
179      */
180     void configure(const CLCompileContext &compile_context, const ICLTensor *input,
181                    const ICLTensor *input_to_forget_weights, const ICLTensor *input_to_cell_weights, const ICLTensor *input_to_output_weights,
182                    const ICLTensor *recurrent_to_forget_weights, const ICLTensor *recurrent_to_cell_weights, const ICLTensor *recurrent_to_output_weights,
183                    const ICLTensor *forget_gate_bias, const ICLTensor *cell_bias, const ICLTensor *output_gate_bias,
184                    ICLTensor *cell_state_in, ICLTensor *output_state_in,
185                    ICLTensor *cell_state_out, ICLTensor *output_state_out, ICLTensor *output,
186                    const LSTMParams<ICLTensor> &lstm_params);
187 
188     /** Static function to check if given info will lead to a valid configuration of @ref CLQLSTMLayer
189      *
190      * @param[in] input                       Source tensor info. Input is a 2D tensor info with dimensions [input_size, batch_size]. Data types supported: QASYMM8_SIGNED.
191      * @param[in] input_to_forget_weights     2D weights tensor info with dimensions [input_size, num_units]. Data type supported: QSYMM8.
192      * @param[in] input_to_cell_weights       2D weights tensor info with dimensions [input_size, num_units]. Data type supported: QSYMM8.
193      * @param[in] input_to_output_weights     2D weights tensor info with dimensions [input_size, num_units]. Data type supported: QSYMM8.
194      * @param[in] recurrent_to_forget_weights 2D weights tensor info with dimensions [output_size, num_units]. Data type supported: QSYMM8.
195      * @param[in] recurrent_to_cell_weights   2D weights tensor info with dimensions [output_size, num_units]. Data type supported: QSYMM8.
196      * @param[in] recurrent_to_output_weights 2D weights tensor info with dimensions [output_size, num_units]. Data type supported: QSYMM8.
197      * @param[in] forget_gate_bias            1D weights tensor info with dimensions [num_units]. Data type supported: S32.
198      * @param[in] cell_bias                   1D weights tensor info with dimensions [num_units]. Data type supported: S32.
199      * @param[in] output_gate_bias            1D weights tensor info with dimensions [num_units]. Data type supported: S32.
200      * @param[in] cell_state_in               2D tensor info with dimensions [num_units, batch_size]. Data type supported: QSYMM16.
201      * @param[in] output_state_in             2D tensor info with dimensions [output_size, batch_size]. Data type supported: Same as @p input.
202      * @param[in] cell_state_out              Destination tensor info. Output is a 2D tensor info with dimensions [num_units, batch_size]. Data type supported: QSYMM16.
203      * @param[in] output_state_out            Destination tensor info. Output is a 2D tensor info with dimensions [output_size, batch_size]. Data types supported: Same as @p input.
204      * @param[in] output                      Destination tensor info. Output is a 2D tensor info with dimensions [output_size, batch_size]. Data types supported: Same as @p input.
205      * @param[in] lstm_params                 Weights tensors info used in peephole, CIFG and layer normalization optimizations:
206      *                                        input_intermediate_scale   Scale of the intermediate result of matmul, i.e. input to layer normalization, at input gate.
207      *                                        forget_intermediate_scale  Scale of the intermediate result of matmul, i.e. input to layer normalization, at forget gate.
208      *                                        cell_intermediate_scale    Scale of the intermediate result of matmul, i.e. input to layer normalization, at cell gate.
209      *                                        output_intermediate_scale  Scale of the intermediate result of matmul, i.e. input to layer normalization, at output gate.
210      *                                        hidden_state_zero          The zero point of the hidden state.
211      *                                        hidden_state_scale         The scale of the hidden state.
212      *                                        input_to_input_weights     (Optional) 2D weights tensor info with dimensions [input_size, num_units]. Data type supported: QSYMM8.
213      *                                        recurrent_to_input_weights (Optional) 2D weights tensor info with dimensions [output_size, num_units]. Data type supported: QSYMM8.
214      *                                        cell_to_input_weights      (Optional) 1D weights tensor info with dimensions [num_units]. Can be nullptr. Data type supported: QSYMM16.
215      *                                        cell_to_forget_weights     (Optional) 1D weights tensor info with dimensions [num_units]. Data type supported: QSYMM16.
216      *                                        cell_to_output_weights     (Optional) 1D weights tensor info with dimensions [num_units]. Data type supported: QSYMM16.
217      *                                        input_gate_bias            (Optional) 1D weights tensor info with dimensions [num_units]. Data type supported: S32.
218      *                                        projection_weights         (Optional) 2D weights tensor info with dimensions [output_size, num_units]. Data type supported: QSYMM8.
219      *                                        projection_bias            (Optional) 1D weights tensor info with dimensions [output_size]. Data type supported: S32.
220      *                                        input_layer_norm_weights   (Optional) 1D weights tensor info with dimensions [num_units]. Data type supported: QSYMM16.
221      *                                        forget_layer_norm_weights  (Optional) 1D weights tensor info with dimensions [num_units]. Data type supported: QSYMM16.
222      *                                        cell_layer_norm_weights    (Optional) 1D weights tensor info with dimensions [num_units]. Data type supported: QSYMM16.
223      *                                        output_layer_norm_weights  (Optional) 1D weights tensor info with dimensions [num_units]. Data type supported: QSYMM16.
224      *                                        cell_threshold             (Optional) The clipping threshold for the cell state, such that values are bound within [-cell_clip, cell_clip].
225      *                                                                              If set to 0.0 then clipping is disabled.
226      *                                        projection_threshold       (Optional) The clipping threshold for the output from the projection layer, such that values are bound within
227      *                                                                              [-proj_clip, proj_clip]. If set to 0.0 then clipping is disabled.
228      * @return a status
229      */
230     static Status validate(const ITensorInfo *input,
231                            const ITensorInfo *input_to_forget_weights, const ITensorInfo *input_to_cell_weights, const ITensorInfo *input_to_output_weights,
232                            const ITensorInfo *recurrent_to_forget_weights, const ITensorInfo *recurrent_to_cell_weights, const ITensorInfo *recurrent_to_output_weights,
233                            const ITensorInfo *forget_gate_bias, const ITensorInfo *cell_bias, const ITensorInfo *output_gate_bias,
234                            const ITensorInfo *cell_state_in, const ITensorInfo *output_state_in,
235                            const ITensorInfo *cell_state_out, const ITensorInfo *output_state_out, const ITensorInfo *output,
236                            const LSTMParams<ITensorInfo> &lstm_params);
237 
238     // Methods overridden from the IFunction base class:
239     void run() override;
240     void prepare() override;
241 
242 private:
243     enum class LayerNormGate : uint8_t // Identifiers for the gates listed below; stored in a single byte
244     {
245         Forget,
246         Cell,
247         Input,
248         Output,
249         Count // Not a gate: equals the number of enumerators above and feeds _layer_norm_count
250     };
251     static constexpr uint8_t  _layer_norm_count                    = static_cast<uint8_t>(LayerNormGate::Count); // Number of LayerNormGate entries (excluding Count itself)
252     static constexpr uint32_t _out_state_output_size_dimension_idx = 0; // NOTE(review): presumably the index of the output_size dimension of the [output_size, batch_size] state tensors - confirm
253 
254     /** Internal method to configure matrix multiplication plus output stage of each gate.
255      *
256      * @param[in] compile_context      The compile context to be used.
257      * @param[in] mm                   Matrix multiplication function to use.
258      * @param[in] outstage             Output stage function to use.
259      * @param[in] gemmlowp_info        GEMMLowp metadata to be used by the output stage.
260      * @param[in] mm_input             Input tensor to matrix multiplication function.
261      * @param[in] mm_weights           Weights tensor to matrix multiplication function.
262      * @param[in] bias                 Bias tensor to matrix multiplication function.
263      * @param[in] mm_res               Tensor for the matrix multiplication result (presumably initialized from @p mm_res_info - TODO confirm in the implementation).
264      * @param[in] outstage_res         Tensor to be used for storing the result of the output stage.
265      * @param[in] gemmlowp_scale       Real multiplier to be used computing multiplier and shift for requantization.
266      * @param[in] mm_res_info          Tensor info to be used to initialize matrix multiplication result tensor.
267      * @param[in] outstage_tensor_info Tensor info to be used to initialize output stage result tensor.
268      */
269     void configure_mm(const CLCompileContext &compile_context, CLGEMMLowpMatrixMultiplyCore &mm, CLGEMMLowpOutputStage &outstage, GEMMLowpOutputStageInfo &gemmlowp_info,
270                       const ICLTensor *mm_input, const ICLTensor *mm_weights, const ICLTensor *bias, CLTensor *mm_res,
271                       CLTensor *outstage_res, float gemmlowp_scale,
272                       const TensorInfo &mm_res_info, const TensorInfo &outstage_tensor_info);
273 
274     MemoryGroup _memory_group{}; // NOTE(review): presumably tied to the memory manager passed to the constructor - confirm in the implementation
275 
276     /** A small internal kernel that performs the copy between two tensors */
277     class TensorCopyKernel
278     {
279         static constexpr uint32_t max_dimension_supported = 2; // NOTE(review): presumably the highest tensor rank accepted by validate() - confirm
280 
281         ICLTensor *_src{ nullptr };
282         ICLTensor *_dst{ nullptr };
283         size_t     _row_size{};
284         Window     _window{};
285 
286     public:
287         /** Static function to check if given info will lead to a valid configuration of @ref CLQLSTMLayer::TensorCopyKernel
288          *
289          * @param[in] src Source tensor info.
290          * @param[in] dst Destination tensor info.
291          *
292          * @return a status
293          */
294         static Status validate(const ITensorInfo &src, const ITensorInfo &dst);
295         /** Set the input and output tensors.
296          *
297          * @param[in]  src Source tensor
298          * @param[out] dst Destination tensor
299          */
300         void configure(ICLTensor &src, ICLTensor &dst);
301         /** Run the kernel */
302         void run();
303     };
304 
305     // Functions and kernels used. Note: ClGemmLowpMatrixAReductionKernel is only forward-declared in this header; the reduction kernels are held via std::unique_ptr.
306     CLTranspose                                                        _transpose_input_to_forget_weights{};
307     CLTranspose                                                        _transpose_input_to_cell_weights{};
308     CLTranspose                                                        _transpose_input_to_output_weights{};
309     CLTranspose                                                        _transpose_input_to_input_weights{};
310     CLTranspose                                                        _transpose_recurrent_to_forget_weights{};
311     CLTranspose                                                        _transpose_recurrent_to_cell_weights{};
312     CLTranspose                                                        _transpose_recurrent_to_output_weights{};
313     CLTranspose                                                        _transpose_recurrent_to_input_weights{};
314     CLTranspose                                                        _transpose_projection_weights{};
315     std::unique_ptr<opencl::kernels::ClGemmLowpMatrixAReductionKernel> _input_to_input_reduction;
316     std::unique_ptr<opencl::kernels::ClGemmLowpMatrixAReductionKernel> _recurrent_to_input_reduction;
317     std::unique_ptr<opencl::kernels::ClGemmLowpMatrixAReductionKernel> _input_to_forget_reduction;
318     std::unique_ptr<opencl::kernels::ClGemmLowpMatrixAReductionKernel> _recurrent_to_forget_reduction;
319     std::unique_ptr<opencl::kernels::ClGemmLowpMatrixAReductionKernel> _input_to_cell_reduction;
320     std::unique_ptr<opencl::kernels::ClGemmLowpMatrixAReductionKernel> _recurrent_to_cell_reduction;
321     std::unique_ptr<opencl::kernels::ClGemmLowpMatrixAReductionKernel> _input_to_output_reduction;
322     std::unique_ptr<opencl::kernels::ClGemmLowpMatrixAReductionKernel> _recurrent_to_output_reduction;
323     std::unique_ptr<opencl::kernels::ClGemmLowpMatrixAReductionKernel> _projection_reduction;
324     CLArithmeticAddition                                               _projection_bias_add{};
325     CLGEMMLowpMatrixMultiplyCore                                       _mm_input_to_forget{};
326     CLGEMMLowpMatrixMultiplyCore                                       _mm_recurrent_to_forget{};
327     CLPixelWiseMultiplication                                          _pixelwise_mul_cell_to_forget{};
328     CLGEMMLowpOutputStage                                              _input_to_forget_outstage{};
329     CLGEMMLowpOutputStage                                              _recurrent_to_forget_outstage{};
330     CLGEMMLowpOutputStage                                              _cell_to_forget_outstage{};
331     CLArithmeticAddition                                               _accumulate_input_recurrent_forget{};
332     CLArithmeticAddition                                               _accumulate_cell_forget{};
333     CLActivationLayer                                                  _forget_gate_sigmoid{};
334     CLGEMMLowpMatrixMultiplyCore                                       _mm_input_to_cell{};
335     CLGEMMLowpOutputStage                                              _input_to_cell_outstage{};
336     CLGEMMLowpMatrixMultiplyCore                                       _mm_recurrent_to_cell{};
337     CLGEMMLowpOutputStage                                              _recurrent_to_cell_outstage{};
338     CLArithmeticAddition                                               _accumulate_input_recurrent_modulation{};
339     CLActivationLayer                                                  _cell_gate_tanh{};
340     CLArithmeticSubtraction                                            _input_gate_sub{};
341     CLGEMMLowpMatrixMultiplyCore                                       _mm_input_to_input{};
342     CLGEMMLowpOutputStage                                              _input_to_input_outstage{};
343     CLGEMMLowpMatrixMultiplyCore                                       _mm_recurrent_to_input{};
344     CLGEMMLowpOutputStage                                              _recurrent_to_input_outstage{};
345     CLArithmeticAddition                                               _accumulate_input_recurrent_input{};
346     CLPixelWiseMultiplication                                          _pixelwise_mul_cell_to_input{};
347     CLGEMMLowpOutputStage                                              _cell_to_input_outstage{};
348     CLArithmeticAddition                                               _accumulate_cell_input{};
349     CLActivationLayer                                                  _input_gate_sigmoid{};
350     CLPixelWiseMultiplication                                          _pixelwise_mul_forget_cell{};
351     CLPixelWiseMultiplication                                          _pixelwise_mul_input_cell{};
352     CLArithmeticAddition                                               _add_forget_cell{};
353     CLActivationLayer                                                  _cell_clip{};
354     CLGEMMLowpMatrixMultiplyCore                                       _mm_input_to_output{};
355     CLGEMMLowpOutputStage                                              _input_to_output_outstage{};
356     CLGEMMLowpMatrixMultiplyCore                                       _mm_recurrent_to_output{};
357     CLGEMMLowpOutputStage                                              _recurrent_to_output_outstage{};
358     CLArithmeticAddition                                               _accumulate_input_recurrent_output{};
359     CLPixelWiseMultiplication                                          _pixelwise_mul_cell_to_output{};
360     CLGEMMLowpOutputStage                                              _cell_to_output_outstage{};
361     CLArithmeticAddition                                               _accumulate_cell_to_output{};
362     CLActivationLayer                                                  _output_gate_sigmoid{};
363     CLActivationLayer                                                  _hidden_tanh{};
364     CLPixelWiseMultiplication                                          _pixelwise_mul_hidden{};
365     CLGEMMLowpOutputStage                                              _hidden_outstage{};
366     CLGEMMLowpMatrixMultiplyCore                                       _mm_projection{};
367     CLGEMMLowpOutputStage                                              _projection_outstage{};
368     CLArithmeticAddition                                               _accumulate_projection{};
369     CLActivationLayer                                                  _projection_clip{};
370     std::array<std::unique_ptr<CLQLSTMLayerNormalizationKernel>, _layer_norm_count> _layer_norms;
371     CLCopy _copy_output;
372 
// Four TensorCopyKernel instances (the type is declared outside this excerpt), each
// value-initialised through an empty-brace default member initializer.
// NOTE(review): the names suggest copies between the projection bias/output/accumulate tensors and
// from the hidden state to the output, presumably only when _projection_tensor_copy_required is
// set - confirm against the implementation file.
373     TensorCopyKernel _projection_bias_copy{};
374     TensorCopyKernel _projection_output_to_accumulate_copy{};
375     TensorCopyKernel _projection_accumulate_to_output_copy{};
376     TensorCopyKernel _hidden_to_output_copy{};
377 
378     // Tensor pointers
379     const ICLTensor *_input_to_input_weights
380     {
381         nullptr
382     };
383     const ICLTensor *_recurrent_to_input_weights{ nullptr };
384     const ICLTensor *_projection_bias{ nullptr };
385     const ICLTensor *_input_to_forget_weights{ nullptr };
386     const ICLTensor *_input_to_cell_weights{ nullptr };
387     const ICLTensor *_input_to_output_weights{ nullptr };
388     const ICLTensor *_recurrent_to_forget_weights{ nullptr };
389     const ICLTensor *_recurrent_to_cell_weights{ nullptr };
390     const ICLTensor *_recurrent_to_output_weights{ nullptr };
391     const ICLTensor *_projection_weights{ nullptr };
392     std::array<const ICLTensor *, _layer_norm_count> _layer_norm_weights{ {} };
393     std::array<const ICLTensor *, _layer_norm_count> _layer_norm_bias{ {} };
394 
395     using LayerNormIndexType = typename std::underlying_type<LayerNormGate>::type;
getGateIndex(LayerNormGate g)396     inline LayerNormIndexType getGateIndex(LayerNormGate g)
397     {
398         return static_cast<LayerNormIndexType>(g);
399     }
400 
set_layer_norm_weight(const ICLTensor * t,LayerNormGate g)401     inline void set_layer_norm_weight(const ICLTensor *t, LayerNormGate g)
402     {
403         _layer_norm_weights[getGateIndex(g)] = t;
404     }
405 
set_layer_norm_bias(const ICLTensor * t,LayerNormGate g)406     inline void set_layer_norm_bias(const ICLTensor *t, LayerNormGate g)
407     {
408         _layer_norm_bias[getGateIndex(g)] = t;
409     }
410 
get_layer_norm_weight(LayerNormGate g)411     inline const ICLTensor *get_layer_norm_weight(LayerNormGate g)
412     {
413         return _layer_norm_weights[getGateIndex(g)];
414     }
415 
get_layer_norm_bias(LayerNormGate g)416     inline const ICLTensor *get_layer_norm_bias(LayerNormGate g)
417     {
418         return _layer_norm_bias[getGateIndex(g)];
419     }
420 
get_layer_norm(LayerNormGate g)421     inline CLQLSTMLayerNormalizationKernel &get_layer_norm(LayerNormGate g)
422     {
423         return *_layer_norms[getGateIndex(g)];
424     }
425 
// Both functions are only declared here; their bodies are not part of this excerpt. Because they
// are declared inline, a definition has to be visible in every translation unit that uses them.
//
// configure_layer_norm: takes a gate selector and an input tensor pointer and returns nothing.
// NOTE(review): presumably configures the kernel returned by get_layer_norm(g) - TODO confirm
// against the definition.
426     inline void configure_layer_norm(LayerNormGate g, const ICLTensor *in);
// validate_layer_norm: static, so it needs no instance. Takes input, weight and bias tensor info
// objects by const reference and returns a Status.
427     inline static Status validate_layer_norm(const ITensorInfo &in, const ITensorInfo &weight, const ITensorInfo &bias);
428 
429     // Temporary tensors
// Every CLTensor below is constructed from a single nullptr argument through its default member
// initializer; what that argument means is fixed by CLTensor's constructor, which is declared
// outside this file excerpt.
// NOTE(review): the name suffixes suggest roles - "_transposed" for transposed weight copies,
// "_eff_bias" for effective biases, "_res" for intermediate results, "_gate" for gate values -
// but this is inferred from the names only; confirm against the implementation file.
430     CLTensor _input_to_forget_weights_transposed{ nullptr };
431     CLTensor _input_to_cell_weights_transposed{ nullptr };
432     CLTensor _input_to_output_weights_transposed{ nullptr };
433     CLTensor _input_to_input_weights_transposed{ nullptr };
434     CLTensor _recurrent_to_forget_weights_transposed{ nullptr };
435     CLTensor _recurrent_to_cell_weights_transposed{ nullptr };
436     CLTensor _recurrent_to_output_weights_transposed{ nullptr };
437     CLTensor _recurrent_to_input_weights_transposed{ nullptr };
438     CLTensor _projection_weights_transposed{ nullptr };
439     CLTensor _input_to_input_eff_bias{ nullptr };
440     CLTensor _recurrent_to_input_eff_bias{ nullptr };
441     CLTensor _input_to_forget_eff_bias{ nullptr };
442     CLTensor _recurrent_to_forget_eff_bias{ nullptr };
443     CLTensor _input_to_cell_eff_bias{ nullptr };
444     CLTensor _recurrent_to_cell_eff_bias{ nullptr };
445     CLTensor _input_to_output_eff_bias{ nullptr };
446     CLTensor _recurrent_to_output_eff_bias{ nullptr };
447     CLTensor _projection_reduction_res{ nullptr };
448     CLTensor _projection_eff_bias{ nullptr };
449     CLTensor _mm_input_to_forget_res{ nullptr };
450     CLTensor _mm_recurrent_to_forget_res{ nullptr };
451     CLTensor _mul_cell_to_forget_res{ nullptr };
452     CLTensor _input_to_forget_outstage_res{ nullptr };
453     CLTensor _cell_to_forget_outstage_res{ nullptr };
454     CLTensor _recurrent_to_forget_outstage_res{ nullptr };
455     CLTensor _forget_gate{ nullptr };
456     CLTensor _mm_input_to_cell_res{ nullptr };
457     CLTensor _input_to_cell_outstage_res{ nullptr };
458     CLTensor _mm_recurrent_to_cell_res{ nullptr };
459     CLTensor _recurrent_to_cell_outstage_res{ nullptr };
460     CLTensor _cell_gate{ nullptr };
461     CLTensor _mul_input_cell_res{ nullptr };
462     CLTensor _mm_input_to_input_res{ nullptr };
463     CLTensor _input_to_input_outstage_res{ nullptr };
464     CLTensor _mm_recurrent_to_input_res{ nullptr };
465     CLTensor _mul_cell_to_input_res{ nullptr };
466     CLTensor _cell_to_input_outstage_res{ nullptr };
467     CLTensor _recurrent_to_input_outstage_res{ nullptr };
468     CLTensor _input_gate{ nullptr };
469     CLTensor _mm_input_to_output_res{ nullptr };
470     CLTensor _input_to_output_outstage_res{ nullptr };
471     CLTensor _mm_recurrent_to_output_res{ nullptr };
472     CLTensor _mul_cell_to_output_res{ nullptr };
473     CLTensor _cell_to_output_outstage_res{ nullptr };
474     CLTensor _recurrent_to_output_outstage_res{ nullptr };
475     CLTensor _output_gate{ nullptr };
476     CLTensor _hidden_mul_res{ nullptr };
477     CLTensor _hidden_gate{ nullptr };
478     CLTensor _mm_projection_res{ nullptr };
479     CLTensor _projection_outstage_res{ nullptr };
480     CLTensor _projection_out_res{ nullptr };
481     CLTensor _projection_accumulate_res{ nullptr };
// NOTE(review): presumably a tensor filled with ones - TODO confirm where it is populated.
482     CLTensor _ones{ nullptr };
// One CLTensor per layer-norm slot (_layer_norm_count entries), reached through
// get_layer_norm_output(). Unlike the tensors above, the elements are initialised from an empty
// initializer list rather than from an explicit nullptr.
483     std::array<CLTensor, _layer_norm_count> _layer_norm_output{ {} };
484 
get_layer_norm_output(LayerNormGate g)485     inline CLTensor &get_layer_norm_output(LayerNormGate g)
486     {
487         return _layer_norm_output[getGateIndex(g)];
488     }
489 
// State flags; every flag defaults to false.
// NOTE(review): the names point at optional features (CIFG, cell/projection clipping, projection,
// peephole, layer normalisation) and at a one-time preparation step; where they are set is not
// visible in this part of the header - confirm against the implementation file.
bool _is_prepared                     = false;
bool _has_cifg                        = false;
bool _has_cell_clipping               = false;
bool _has_projection                  = false;
bool _has_projection_clipping         = false;
bool _has_peephole                    = false;
bool _has_layer_norm                  = false;
bool _projection_tensor_copy_required = false;
498 };
499 } // namespace arm_compute
500 #endif /* ARM_COMPUTE_CLQLSTMLAYER_H */
501