/*
 * Copyright (c) 2018-2021 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifndef ARM_COMPUTE_CLLSTMLAYER_H
#define ARM_COMPUTE_CLLSTMLAYER_H

#include "arm_compute/runtime/IFunction.h"

#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/CL/CLTensor.h"
#include "arm_compute/runtime/CL/functions/CLActivationLayer.h"
#include "arm_compute/runtime/CL/functions/CLConcatenateLayer.h"
#include "arm_compute/runtime/CL/functions/CLCopy.h"
#include "arm_compute/runtime/CL/functions/CLElementwiseOperations.h"
#include "arm_compute/runtime/CL/functions/CLFill.h"
#include "arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h"
#include "arm_compute/runtime/CL/functions/CLGEMM.h"
#include "arm_compute/runtime/CL/functions/CLMeanStdDevNormalizationLayer.h"
#include "arm_compute/runtime/CL/functions/CLPixelWiseMultiplication.h"
#include "arm_compute/runtime/IMemoryManager.h"
#include "arm_compute/runtime/MemoryGroup.h"
#include "arm_compute/runtime/common/LSTMParams.h"

#include <memory>

namespace arm_compute
{
class CLCompileContext;
class ICLTensor;
namespace opencl
{
namespace kernels
{
class ClTransposeKernel;
}
}

/** This function performs a single time step in a Long Short-Term Memory (LSTM) layer.
 *
 * The optional CIFG, peephole, projection and layer-normalization paths are selected through the
 * @ref LSTMParams argument passed to configure() and validate().
 */
class CLLSTMLayer : public IFunction
{
public:
    /** Default constructor */
    CLLSTMLayer(std::shared_ptr<IMemoryManager> memory_manager = nullptr);
    /** Prevent instances of this class from being copied */
    CLLSTMLayer(const CLLSTMLayer &) = delete;
    /** Prevent instances of this class from being copied */
    CLLSTMLayer &operator=(const CLLSTMLayer &) = delete;
    /** Prevent instances of this class from being moved */
    CLLSTMLayer(CLLSTMLayer &&) = delete;
    /** Prevent instances of this class from being moved */
    CLLSTMLayer &operator=(CLLSTMLayer &&) = delete;
    /** Default destructor */
    ~CLLSTMLayer();
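    /* For reference, the textbook LSTM step that the parameter names documented below follow
     * (optional peephole, projection and layer-normalization terms omitted). This is the standard
     * formulation, included as a reading aid rather than text extracted from the implementation:
     *
     *   f_t = sigmoid(input_to_forget_weights * x_t + recurrent_to_forget_weights * h_{t-1} + forget_gate_bias)
     *   i_t = sigmoid(input_to_input_weights  * x_t + recurrent_to_input_weights  * h_{t-1} + input_gate_bias)
     *   g_t =       g(input_to_cell_weights   * x_t + recurrent_to_cell_weights   * h_{t-1} + cell_bias)
     *   c_t = f_t (.) c_{t-1} + i_t (.) g_t
     *   o_t = sigmoid(input_to_output_weights * x_t + recurrent_to_output_weights * h_{t-1} + output_gate_bias)
     *   h_t = o_t (.) g(c_t)
     *
     * where x_t is `input`, h_{t-1} is `output_state_in`, c_{t-1} is `cell_state_in`, (.) is an
     * element-wise product and g is the activation supplied via `activation_info` (typically tanh).
     * The input-gate tensors are supplied through `lstm_params`; with CIFG enabled, i_t = 1 - f_t
     * is used instead.
     */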
    /** Initialize the function's tensors.
     *
     * Valid data layouts:
     * - All
     *
     * Valid data type configurations:
     * |src0 - src13 | dst0 - dst3 |
     * |:------------|:------------|
     * |F16          |F16          |
     * |F32          |F32          |
     *
     * @param[in]  input                       Source tensor. Input is a 2D tensor with dimensions [input_size, batch_size]. Data types supported: F16/F32.
     * @param[in]  input_to_forget_weights     2D weights tensor with dimensions [input_size, num_units]. Data type supported: Same as @p input.
     * @param[in]  input_to_cell_weights       2D weights tensor with dimensions [input_size, num_units]. Data type supported: Same as @p input.
     * @param[in]  input_to_output_weights     2D weights tensor with dimensions [input_size, num_units]. Data type supported: Same as @p input.
     * @param[in]  recurrent_to_forget_weights 2D weights tensor with dimensions [output_size, num_units]. Data type supported: Same as @p input.
     * @param[in]  recurrent_to_cell_weights   2D weights tensor with dimensions [output_size, num_units]. Data type supported: Same as @p input.
     * @param[in]  recurrent_to_output_weights 2D weights tensor with dimensions [output_size, num_units]. Data type supported: Same as @p input.
     * @param[in]  forget_gate_bias            1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input.
     * @param[in]  cell_bias                   1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input.
     * @param[in]  output_gate_bias            1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input.
     * @param[in]  output_state_in             2D tensor with dimensions [output_size, batch_size]. Data type supported: Same as @p input.
     * @param[in]  cell_state_in               2D tensor with dimensions [num_units, batch_size]. Data type supported: Same as @p input.
     * @param[out] scratch_buffer              2D tensor with dimensions [num_units * 4, batch_size] without CIFG or [num_units * 3, batch_size] with CIFG. Data type supported: Same as @p input.
     * @param[out] output_state_out            2D tensor with dimensions [output_size, batch_size]. Data type supported: Same as @p input.
     * @param[out] cell_state_out              2D tensor with dimensions [num_units, batch_size]. Data type supported: Same as @p input.
     * @param[out] output                      Destination tensor. Output is a 2D tensor with dimensions [output_size, batch_size].
     *                                         Data types supported: Same as @p input.
     * @param[in]  lstm_params                 Weights tensors used in the optional CIFG, peephole, projection and layer-normalization paths:
     *                                         input_to_input_weights     2D weights tensor with dimensions [input_size, num_units]. Data type supported: Same as @p input.
     *                                         recurrent_to_input_weights 2D weights tensor with dimensions [output_size, num_units]. Data type supported: Same as @p input.
     *                                         cell_to_input_weights      1D weights tensor with dimensions [num_units]. Can be nullptr. Data type supported: Same as @p input.
     *                                         cell_to_forget_weights     1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input.
     *                                         cell_to_output_weights     1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input.
     *                                         input_gate_bias            1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input.
     *                                         projection_weights         2D weights tensor with dimensions [output_size, num_units]. Data type supported: Same as @p input.
     *                                         projection_bias            1D weights tensor with dimensions [output_size]. Data type supported: Same as @p input.
     *                                         input_layer_norm_weights   1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input.
     *                                         forget_layer_norm_weights  1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input.
     *                                         cell_layer_norm_weights    1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input.
     *                                         output_layer_norm_weights  1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input.
     * @param[in]  activation_info             Contains activation information described in @ref ActivationLayerInfo.
     * @param[in]  cell_threshold              (Optional) The clipping threshold for the cell state, such that values are bound within [-cell_clip, cell_clip].
     *                                         If set to 0.0f then clipping is disabled.
     * @param[in]  projection_threshold        (Optional) The clipping threshold for the output from the projection layer, such that values are bound within [-proj_clip, proj_clip].
     *                                         If set to 0.0f then clipping is disabled.
     */
    void configure(const ICLTensor *input,
                   const ICLTensor *input_to_forget_weights, const ICLTensor *input_to_cell_weights, const ICLTensor *input_to_output_weights,
                   const ICLTensor *recurrent_to_forget_weights, const ICLTensor *recurrent_to_cell_weights, const ICLTensor *recurrent_to_output_weights,
                   const ICLTensor *forget_gate_bias, const ICLTensor *cell_bias, const ICLTensor *output_gate_bias,
                   const ICLTensor *output_state_in, ICLTensor *cell_state_in,
                   ICLTensor *scratch_buffer, ICLTensor *output_state_out, ICLTensor *cell_state_out, ICLTensor *output,
                   const LSTMParams<ICLTensor> &lstm_params, const ActivationLayerInfo &activation_info, float cell_threshold = 0.f, float projection_threshold = 0.f);
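    /* A minimal sketch of how configure() might be called for a full (non-CIFG) LSTM step. The
     * tensor names and sizes are illustrative assumptions, the remaining tensors are initialised
     * analogously to the ones shown, and the set_cifg_params() setter comes from
     * arm_compute/runtime/common/LSTMParams.h (check its exact signature and semantics there).
     *
     *   CLScheduler::get().default_init(); // once per process, before configuring CL functions
     *
     *   const unsigned int input_size = 32, num_units = 64, output_size = 64, batch_size = 1;
     *
     *   CLTensor input, input_to_forget_w, input_to_cell_w, input_to_output_w;
     *   CLTensor recurrent_to_forget_w, recurrent_to_cell_w, recurrent_to_output_w;
     *   CLTensor forget_gate_bias, cell_bias, output_gate_bias;
     *   CLTensor input_to_input_w, recurrent_to_input_w, input_gate_bias;
     *   CLTensor output_state_in, cell_state_in, scratch, output_state_out, cell_state_out, output;
     *
     *   input.allocator()->init(TensorInfo(TensorShape(input_size, batch_size), 1, DataType::F32));
     *   input_to_forget_w.allocator()->init(TensorInfo(TensorShape(input_size, num_units), 1, DataType::F32));
     *   // ... init the remaining tensors with the shapes documented above ...
     *   // Scratch is num_units * 4 wide without CIFG and num_units * 3 wide with CIFG.
     *   scratch.allocator()->init(TensorInfo(TensorShape(num_units * 4, batch_size), 1, DataType::F32));
     *
     *   LSTMParams<ICLTensor> lstm_params;
     *   // Supplying the input-gate tensors here configures the non-CIFG path
     *   // (cell_to_input_weights may be nullptr); see LSTMParams.h for the exact semantics.
     *   lstm_params.set_cifg_params(&input_to_input_w, &recurrent_to_input_w, nullptr, &input_gate_bias);
     *
     *   CLLSTMLayer lstm;
     *   lstm.configure(&input,
     *                  &input_to_forget_w, &input_to_cell_w, &input_to_output_w,
     *                  &recurrent_to_forget_w, &recurrent_to_cell_w, &recurrent_to_output_w,
     *                  &forget_gate_bias, &cell_bias, &output_gate_bias,
     *                  &output_state_in, &cell_state_in,
     *                  &scratch, &output_state_out, &cell_state_out, &output,
     *                  lstm_params, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH));
     */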
    /** Initialize the function's tensors.
     *
     * @param[in]  compile_context             The compile context to be used.
     * @param[in]  input                       Source tensor. Input is a 2D tensor with dimensions [input_size, batch_size]. Data types supported: F16/F32.
     * @param[in]  input_to_forget_weights     2D weights tensor with dimensions [input_size, num_units]. Data type supported: Same as @p input.
     * @param[in]  input_to_cell_weights       2D weights tensor with dimensions [input_size, num_units]. Data type supported: Same as @p input.
     * @param[in]  input_to_output_weights     2D weights tensor with dimensions [input_size, num_units]. Data type supported: Same as @p input.
     * @param[in]  recurrent_to_forget_weights 2D weights tensor with dimensions [output_size, num_units]. Data type supported: Same as @p input.
     * @param[in]  recurrent_to_cell_weights   2D weights tensor with dimensions [output_size, num_units]. Data type supported: Same as @p input.
     * @param[in]  recurrent_to_output_weights 2D weights tensor with dimensions [output_size, num_units]. Data type supported: Same as @p input.
     * @param[in]  forget_gate_bias            1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input.
     * @param[in]  cell_bias                   1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input.
     * @param[in]  output_gate_bias            1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input.
     * @param[in]  output_state_in             2D tensor with dimensions [output_size, batch_size]. Data type supported: Same as @p input.
     * @param[in]  cell_state_in               2D tensor with dimensions [num_units, batch_size]. Data type supported: Same as @p input.
     * @param[out] scratch_buffer              2D tensor with dimensions [num_units * 4, batch_size] without CIFG or [num_units * 3, batch_size] with CIFG. Data type supported: Same as @p input.
     * @param[out] output_state_out            2D tensor with dimensions [output_size, batch_size]. Data type supported: Same as @p input.
     * @param[out] cell_state_out              2D tensor with dimensions [num_units, batch_size]. Data type supported: Same as @p input.
     * @param[out] output                      Destination tensor. Output is a 2D tensor with dimensions [output_size, batch_size].
     *                                         Data types supported: Same as @p input.
     * @param[in]  lstm_params                 Weights tensors used in the optional CIFG, peephole, projection and layer-normalization paths:
     *                                         input_to_input_weights     2D weights tensor with dimensions [input_size, num_units]. Data type supported: Same as @p input.
     *                                         recurrent_to_input_weights 2D weights tensor with dimensions [output_size, num_units]. Data type supported: Same as @p input.
     *                                         cell_to_input_weights      1D weights tensor with dimensions [num_units]. Can be nullptr. Data type supported: Same as @p input.
     *                                         cell_to_forget_weights     1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input.
     *                                         cell_to_output_weights     1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input.
     *                                         input_gate_bias            1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input.
     *                                         projection_weights         2D weights tensor with dimensions [output_size, num_units]. Data type supported: Same as @p input.
     *                                         projection_bias            1D weights tensor with dimensions [output_size]. Data type supported: Same as @p input.
     *                                         input_layer_norm_weights   1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input.
     *                                         forget_layer_norm_weights  1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input.
     *                                         cell_layer_norm_weights    1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input.
     *                                         output_layer_norm_weights  1D weights tensor with dimensions [num_units]. Data type supported: Same as @p input.
     * @param[in]  activation_info             Contains activation information described in @ref ActivationLayerInfo.
     * @param[in]  cell_threshold              (Optional) The clipping threshold for the cell state, such that values are bound within [-cell_clip, cell_clip].
     *                                         If set to 0.0f then clipping is disabled.
     * @param[in]  projection_threshold        (Optional) The clipping threshold for the output from the projection layer, such that values are bound within [-proj_clip, proj_clip].
     *                                         If set to 0.0f then clipping is disabled.
     */
    void configure(const CLCompileContext &compile_context, const ICLTensor *input,
                   const ICLTensor *input_to_forget_weights, const ICLTensor *input_to_cell_weights, const ICLTensor *input_to_output_weights,
                   const ICLTensor *recurrent_to_forget_weights, const ICLTensor *recurrent_to_cell_weights, const ICLTensor *recurrent_to_output_weights,
                   const ICLTensor *forget_gate_bias, const ICLTensor *cell_bias, const ICLTensor *output_gate_bias,
                   const ICLTensor *output_state_in, ICLTensor *cell_state_in,
                   ICLTensor *scratch_buffer, ICLTensor *output_state_out, ICLTensor *cell_state_out, ICLTensor *output,
                   const LSTMParams<ICLTensor> &lstm_params, const ActivationLayerInfo &activation_info, float cell_threshold = 0.f, float projection_threshold = 0.f);
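    /* Sketch of how the remaining optional paths could be enabled on the LSTMParams object before
     * calling either configure() overload. The setter names are taken from
     * arm_compute/runtime/common/LSTMParams.h; treat the exact parameter lists as assumptions to
     * verify against that header, and the tensors below as hypothetical.
     *
     *   LSTMParams<ICLTensor> lstm_params;
     *   // Peephole connections: cell-to-forget and cell-to-output weights.
     *   lstm_params.set_peephole_params(&cell_to_forget_w, &cell_to_output_w);
     *   // Projection of the output state down to output_size (projection_bias may be nullptr).
     *   lstm_params.set_projection_params(&projection_w, &projection_bias);
     *   // Layer-normalization coefficients for the input/forget/cell/output gates.
     *   lstm_params.set_layer_normalization_params(&input_ln_w, &forget_ln_w, &cell_ln_w, &output_ln_w);
     *
     *   lstm.configure(compile_context, &input, ..., &output,
     *                  lstm_params, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH),
     *                  1.0f,  // cell_threshold: bound the cell state to [-1, 1]
     *                  0.f);  // projection_threshold: 0 disables projection clipping
     */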

    /** Static function to check if the given info will lead to a valid configuration of @ref CLLSTMLayer
     *
     * @param[in] input                       Source tensor info. Input is a 2D tensor with dimensions [input_size, batch_size]. Data types supported: F16/F32.
     * @param[in] input_to_forget_weights     2D weights tensor info with dimensions [input_size, num_units]. Data type supported: Same as @p input.
     * @param[in] input_to_cell_weights       2D weights tensor info with dimensions [input_size, num_units]. Data type supported: Same as @p input.
     * @param[in] input_to_output_weights     2D weights tensor info with dimensions [input_size, num_units]. Data type supported: Same as @p input.
     * @param[in] recurrent_to_forget_weights 2D weights tensor info with dimensions [output_size, num_units]. Data type supported: Same as @p input.
     * @param[in] recurrent_to_cell_weights   2D weights tensor info with dimensions [output_size, num_units]. Data type supported: Same as @p input.
     * @param[in] recurrent_to_output_weights 2D weights tensor info with dimensions [output_size, num_units]. Data type supported: Same as @p input.
     * @param[in] forget_gate_bias            1D weights tensor info with dimensions [num_units]. Data type supported: Same as @p input.
     * @param[in] cell_bias                   1D weights tensor info with dimensions [num_units]. Data type supported: Same as @p input.
     * @param[in] output_gate_bias            1D weights tensor info with dimensions [num_units]. Data type supported: Same as @p input.
     * @param[in] output_state_in             2D tensor info with dimensions [output_size, batch_size]. Data type supported: Same as @p input.
     * @param[in] cell_state_in               2D tensor info with dimensions [num_units, batch_size]. Data type supported: Same as @p input.
     * @param[in] scratch_buffer              2D tensor info with dimensions [num_units * 4, batch_size] without CIFG or [num_units * 3, batch_size] with CIFG.
     *                                        Data type supported: Same as @p input.
     * @param[in] output_state_out            2D tensor info with dimensions [output_size, batch_size]. Data type supported: Same as @p input.
     * @param[in] cell_state_out              2D tensor info with dimensions [num_units, batch_size]. Data type supported: Same as @p input.
     * @param[in] output                      Destination tensor info. Output is a 2D tensor with dimensions [output_size, batch_size]. Data types supported: Same as @p input.
     * @param[in] lstm_params                 Weights tensors info used in the optional CIFG, peephole, projection and layer-normalization paths:
     *                                        input_to_input_weights     2D weights tensor info with dimensions [input_size, num_units]. Data type supported: Same as @p input.
     *                                        recurrent_to_input_weights 2D weights tensor info with dimensions [output_size, num_units]. Data type supported: Same as @p input.
     *                                        cell_to_input_weights      1D weights tensor info with dimensions [num_units]. Can be nullptr. Data type supported: Same as @p input.
     *                                        cell_to_forget_weights     1D weights tensor info with dimensions [num_units]. Data type supported: Same as @p input.
     *                                        cell_to_output_weights     1D weights tensor info with dimensions [num_units]. Data type supported: Same as @p input.
     *                                        input_gate_bias            1D weights tensor info with dimensions [num_units]. Data type supported: Same as @p input.
     *                                        projection_weights         2D weights tensor info with dimensions [output_size, num_units]. Data type supported: Same as @p input.
     *                                        projection_bias            1D weights tensor info with dimensions [output_size]. Data type supported: Same as @p input.
     *                                        input_layer_norm_weights   1D weights tensor info with dimensions [num_units]. Data type supported: Same as @p input.
     *                                        forget_layer_norm_weights  1D weights tensor info with dimensions [num_units]. Data type supported: Same as @p input.
     *                                        cell_layer_norm_weights    1D weights tensor info with dimensions [num_units]. Data type supported: Same as @p input.
     *                                        output_layer_norm_weights  1D weights tensor info with dimensions [num_units]. Data type supported: Same as @p input.
     * @param[in] activation_info             Contains activation information described in @ref ActivationLayerInfo.
     * @param[in] cell_threshold              (Optional) The clipping threshold for the cell state, such that values are bound within [-cell_clip, cell_clip].
     *                                        If set to 0.0f then clipping is disabled.
     * @param[in] projection_threshold        (Optional) The clipping threshold for the output from the projection layer, such that values are bound within [-proj_clip, proj_clip].
     *                                        If set to 0.0f then clipping is disabled.
     *
     * @return a status
     */
    static Status validate(const ITensorInfo *input,
                           const ITensorInfo *input_to_forget_weights, const ITensorInfo *input_to_cell_weights, const ITensorInfo *input_to_output_weights,
                           const ITensorInfo *recurrent_to_forget_weights, const ITensorInfo *recurrent_to_cell_weights, const ITensorInfo *recurrent_to_output_weights,
                           const ITensorInfo *forget_gate_bias, const ITensorInfo *cell_bias, const ITensorInfo *output_gate_bias,
                           const ITensorInfo *output_state_in, const ITensorInfo *cell_state_in,
                           const ITensorInfo *scratch_buffer, const ITensorInfo *output_state_out, const ITensorInfo *cell_state_out, const ITensorInfo *output,
                           const LSTMParams<ITensorInfo> &lstm_params, const ActivationLayerInfo &activation_info, float cell_threshold = 0.f, float projection_threshold = 0.f);
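    /* Sketch: validate() works on ITensorInfo descriptions only, so a configuration can be checked
     * before any OpenCL buffers are allocated. The shapes and names below are illustrative; the
     * remaining TensorInfo objects follow the shapes documented above.
     *
     *   const TensorInfo input_info(TensorShape(input_size, batch_size), 1, DataType::F32);
     *   const TensorInfo in2forget_info(TensorShape(input_size, num_units), 1, DataType::F32);
     *   // ... build the other TensorInfo objects the same way ...
     *
     *   LSTMParams<ITensorInfo> params; // populate the optional paths as for configure()
     *   const Status status = CLLSTMLayer::validate(&input_info,
     *                                               &in2forget_info, &in2cell_info, &in2output_info,
     *                                               &rec2forget_info, &rec2cell_info, &rec2output_info,
     *                                               &forget_bias_info, &cell_bias_info, &output_bias_info,
     *                                               &output_state_in_info, &cell_state_in_info,
     *                                               &scratch_info, &output_state_out_info, &cell_state_out_info, &output_info,
     *                                               params, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH));
     *   if(!status)
     *   {
     *       // status.error_description() explains why the configuration was rejected.
     *   }
     */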

    // Inherited methods overridden:
    void run() override;
    void prepare() override;
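    /* Sketch of the execution phase, assuming the tensors from the configure() sketch above have
     * been allocated (e.g. tensor.allocator()->allocate()) and filled, and that CLScheduler was
     * initialised before configuration. prepare() performs the one-off weight reshaping; run()
     * typically triggers it on first use, so calling it explicitly is optional.
     *
     *   for(unsigned int t = 0; t < num_time_steps; ++t)
     *   {
     *       // ... write time step t into `input` and feed output_state_out / cell_state_out
     *       //     back into output_state_in / cell_state_in for the next step ...
     *       lstm.run();
     *   }
     *   CLScheduler::get().sync(); // wait for the queued OpenCL work to finish
     */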

private:
    MemoryGroup                                          _memory_group;
    CLFullyConnectedLayer                                _fully_connected_input_gate;
    CLArithmeticAddition                                 _accum_input_gate1;
    CLArithmeticSubtraction                              _subtract_input_gate;
    CLPixelWiseMultiplication                            _pixelwise_mul_input_gate;
    CLActivationLayer                                    _activation_input_gate;
    CLFullyConnectedLayer                                _fully_connected_forget_gate;
    CLArithmeticAddition                                 _accum_forget_gate1;
    CLPixelWiseMultiplication                            _pixelwise_mul_forget_gate;
    CLActivationLayer                                    _activation_forget_gate;
    CLFullyConnectedLayer                                _fully_connected_cell_state;
    CLGEMM                                               _gemm_cell_state1;
    std::unique_ptr<opencl::kernels::ClTransposeKernel>  _transpose_cell_state;
    CLArithmeticAddition                                 _accum_cell_state1;
    CLArithmeticAddition                                 _accum_cell_state2;
    CLPixelWiseMultiplication                            _pixelwise_mul_cell_state1;
    CLActivationLayer                                    _activation_cell_state;
    CLActivationLayer                                    _cell_clip;
    CLPixelWiseMultiplication                            _pixelwise_mul_cell_state2;
    CLFullyConnectedLayer                                _fully_connected_output;
    CLPixelWiseMultiplication                            _pixelwise_mul_output_state1;
    CLArithmeticAddition                                 _accum_output1;
    CLActivationLayer                                    _activation_output;
    CLActivationLayer                                    _activation_output_state;
    CLPixelWiseMultiplication                            _pixelwise_mul_output_state2;
    CLFullyConnectedLayer                                _fully_connected_output_state;
    CLActivationLayer                                    _projection_clip;
    CLCopy                                               _copy_cell_state;
    CLCopy                                               _copy_output;
    CLConcatenateLayer                                   _concat_scratch_buffer;
    CLConcatenateLayer                                   _concat_inputs_forget_gate;
    CLConcatenateLayer                                   _concat_weights_forget_gate;
    CLConcatenateLayer                                   _concat_weights_input_gate;
    CLConcatenateLayer                                   _concat_weights_output;
    CLFill                                               _ones_fill;
    CLMeanStdDevNormalizationLayer                       _mean_std_norm_input_gate;
    CLPixelWiseMultiplication                            _pixelwise_mul_input_gate_coeff;
    CLArithmeticAddition                                 _accum_input_gate_bias;
    CLMeanStdDevNormalizationLayer                       _mean_std_norm_forget_gate;
    CLPixelWiseMultiplication                            _pixelwise_mul_forget_gate_coeff;
    CLArithmeticAddition                                 _accum_forget_gate_bias;
    CLMeanStdDevNormalizationLayer                       _mean_std_norm_cell_gate;
    CLPixelWiseMultiplication                            _pixelwise_mul_cell_gate_coeff;
    CLArithmeticAddition                                 _accum_cell_gate_bias;
    CLMeanStdDevNormalizationLayer                       _mean_std_norm_output_gate;
    CLPixelWiseMultiplication                            _pixelwise_mul_output_gate_coeff;
    CLArithmeticAddition                                 _accum_output_gate_bias;
    CLTensor                                             _input_gate_out1;
    CLTensor                                             _input_gate_out2;
    CLTensor                                             _input_gate_out3;
    CLTensor                                             _input_gate_out4;
    CLTensor                                             _forget_gate_out1;
    CLTensor                                             _forget_gate_out2;
    CLTensor                                             _forget_gate_out3;
    CLTensor                                             _forget_gate_out4;
    CLTensor                                             _forget_gate_out5;
    CLTensor                                             _forget_gate_out6;
    CLTensor                                             _cell_state_out1;
    CLTensor                                             _cell_state_out2;
    CLTensor                                             _cell_state_out3;
    CLTensor                                             _cell_state_out4;
    CLTensor                                             _cell_state_out5;
    CLTensor                                             _output1;
    CLTensor                                             _output2;
    CLTensor                                             _output3;
    CLTensor                                             _output4;
    CLTensor                                             _cell_state_activation;
    CLTensor                                             _output_state1;
    CLTensor                                             _ones;
    CLTensor                                             _input_layer_norm_out1;
    CLTensor                                             _input_layer_norm_out2;
    CLTensor                                             _forget_layer_norm_out1;
    CLTensor                                             _forget_layer_norm_out2;
    CLTensor                                             _cell_layer_norm_out1;
    CLTensor                                             _cell_layer_norm_out2;
    CLTensor                                             _output_layer_norm_out1;
    CLTensor                                             _output_layer_norm_out2;
    bool                                                 _run_peephole_opt;
    bool                                                 _run_cifg_opt;
    bool                                                 _perform_cell_clipping;
    bool                                                 _has_projection_weights;
    bool                                                 _perform_projection_clipping;
    bool                                                 _is_prepared;
    bool                                                 _is_layer_norm_lstm;
    const ICLTensor                                     *_recurrent_to_cell_weights{ nullptr };
};
} // namespace arm_compute
#endif /* ARM_COMPUTE_CLLSTMLAYER_H */