xref: /aosp_15_r20/external/ComputeLibrary/src/runtime/CL/functions/CLLSTMLayer.cpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1*c217d954SCole Faust /*
2*c217d954SCole Faust  * Copyright (c) 2018-2021 Arm Limited.
3*c217d954SCole Faust  *
4*c217d954SCole Faust  * SPDX-License-Identifier: MIT
5*c217d954SCole Faust  *
6*c217d954SCole Faust  * Permission is hereby granted, free of charge, to any person obtaining a copy
7*c217d954SCole Faust  * of this software and associated documentation files (the "Software"), to
8*c217d954SCole Faust  * deal in the Software without restriction, including without limitation the
9*c217d954SCole Faust  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10*c217d954SCole Faust  * sell copies of the Software, and to permit persons to whom the Software is
11*c217d954SCole Faust  * furnished to do so, subject to the following conditions:
12*c217d954SCole Faust  *
13*c217d954SCole Faust  * The above copyright notice and this permission notice shall be included in all
14*c217d954SCole Faust  * copies or substantial portions of the Software.
15*c217d954SCole Faust  *
16*c217d954SCole Faust  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17*c217d954SCole Faust  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18*c217d954SCole Faust  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19*c217d954SCole Faust  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20*c217d954SCole Faust  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21*c217d954SCole Faust  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22*c217d954SCole Faust  * SOFTWARE.
23*c217d954SCole Faust  */
24*c217d954SCole Faust #include "arm_compute/runtime/CL/functions/CLLSTMLayer.h"
25*c217d954SCole Faust 
26*c217d954SCole Faust #include "arm_compute/core/Utils.h"
27*c217d954SCole Faust #include "arm_compute/core/Validate.h"
28*c217d954SCole Faust #include "arm_compute/core/utils/misc/InfoHelpers.h"
29*c217d954SCole Faust #include "arm_compute/core/utils/misc/ShapeCalculator.h"
30*c217d954SCole Faust #include "arm_compute/core/utils/quantization/AsymmHelpers.h"
31*c217d954SCole Faust #include "arm_compute/runtime/CL/CLScheduler.h"
32*c217d954SCole Faust #include "src/core/CL/kernels/CLFillBorderKernel.h"
33*c217d954SCole Faust #include "src/gpu/cl/kernels/ClTransposeKernel.h"
34*c217d954SCole Faust 
35*c217d954SCole Faust #include "src/common/utils/Log.h"
36*c217d954SCole Faust 
37*c217d954SCole Faust namespace arm_compute
38*c217d954SCole Faust {
39*c217d954SCole Faust using namespace arm_compute::misc::shape_calculator;
40*c217d954SCole Faust using namespace arm_compute::utils::info_helpers;
41*c217d954SCole Faust 
CLLSTMLayer(std::shared_ptr<IMemoryManager> memory_manager)42*c217d954SCole Faust CLLSTMLayer::CLLSTMLayer(std::shared_ptr<IMemoryManager> memory_manager)
43*c217d954SCole Faust     : _memory_group(std::move(memory_manager)), _fully_connected_input_gate(), _accum_input_gate1(), _subtract_input_gate(), _pixelwise_mul_input_gate(), _activation_input_gate(),
44*c217d954SCole Faust       _fully_connected_forget_gate(), _accum_forget_gate1(), _pixelwise_mul_forget_gate(), _activation_forget_gate(), _fully_connected_cell_state(), _gemm_cell_state1(),
45*c217d954SCole Faust       _transpose_cell_state(std::make_unique<opencl::kernels::ClTransposeKernel>()), _accum_cell_state1(), _accum_cell_state2(), _pixelwise_mul_cell_state1(), _activation_cell_state(), _cell_clip(),
46*c217d954SCole Faust       _pixelwise_mul_cell_state2(), _fully_connected_output(), _pixelwise_mul_output_state1(), _accum_output1(), _activation_output(), _activation_output_state(), _pixelwise_mul_output_state2(),
47*c217d954SCole Faust       _fully_connected_output_state(), _projection_clip(), _copy_cell_state(), _copy_output(), _concat_scratch_buffer(), _concat_inputs_forget_gate(), _concat_weights_forget_gate(),
48*c217d954SCole Faust       _concat_weights_input_gate(), _concat_weights_output(), _ones_fill(), _mean_std_norm_input_gate(), _pixelwise_mul_input_gate_coeff(), _accum_input_gate_bias(), _mean_std_norm_forget_gate(),
49*c217d954SCole Faust       _pixelwise_mul_forget_gate_coeff(), _accum_forget_gate_bias(), _mean_std_norm_cell_gate(), _pixelwise_mul_cell_gate_coeff(), _accum_cell_gate_bias(), _mean_std_norm_output_gate(),
50*c217d954SCole Faust       _pixelwise_mul_output_gate_coeff(), _accum_output_gate_bias(), _input_gate_out1(), _input_gate_out2(), _input_gate_out3(), _input_gate_out4(), _forget_gate_out1(), _forget_gate_out2(),
51*c217d954SCole Faust       _forget_gate_out3(), _forget_gate_out4(), _forget_gate_out5(), _forget_gate_out6(), _cell_state_out1(), _cell_state_out2(), _cell_state_out3(), _cell_state_out4(), _cell_state_out5(), _output1(),
52*c217d954SCole Faust       _output2(), _output3(), _output4(), _cell_state_activation(), _output_state1(), _ones(), _input_layer_norm_out1(), _input_layer_norm_out2(), _forget_layer_norm_out1(), _forget_layer_norm_out2(),
53*c217d954SCole Faust       _cell_layer_norm_out1(), _cell_layer_norm_out2(), _output_layer_norm_out1(), _output_layer_norm_out2(), _run_peephole_opt(false), _run_cifg_opt(false), _perform_cell_clipping(false),
54*c217d954SCole Faust       _has_projection_weights(false), _perform_projection_clipping(false), _is_prepared(false), _is_layer_norm_lstm(false)
55*c217d954SCole Faust {
56*c217d954SCole Faust }
57*c217d954SCole Faust 
58*c217d954SCole Faust CLLSTMLayer::~CLLSTMLayer() = default;
59*c217d954SCole Faust 
configure(const ICLTensor * input,const ICLTensor * input_to_forget_weights,const ICLTensor * input_to_cell_weights,const ICLTensor * input_to_output_weights,const ICLTensor * recurrent_to_forget_weights,const ICLTensor * recurrent_to_cell_weights,const ICLTensor * recurrent_to_output_weights,const ICLTensor * forget_gate_bias,const ICLTensor * cell_bias,const ICLTensor * output_gate_bias,const ICLTensor * output_state_in,ICLTensor * cell_state_in,ICLTensor * scratch_buffer,ICLTensor * output_state_out,ICLTensor * cell_state_out,ICLTensor * output,const LSTMParams<ICLTensor> & lstm_params,const ActivationLayerInfo & activation_info,float cell_threshold,float projection_threshold)60*c217d954SCole Faust void CLLSTMLayer::configure(const ICLTensor *input,
61*c217d954SCole Faust                             const ICLTensor *input_to_forget_weights, const ICLTensor *input_to_cell_weights, const ICLTensor *input_to_output_weights,
62*c217d954SCole Faust                             const ICLTensor *recurrent_to_forget_weights, const ICLTensor *recurrent_to_cell_weights, const ICLTensor *recurrent_to_output_weights,
63*c217d954SCole Faust                             const ICLTensor *forget_gate_bias, const ICLTensor *cell_bias, const ICLTensor *output_gate_bias,
64*c217d954SCole Faust                             const ICLTensor *output_state_in, ICLTensor *cell_state_in,
65*c217d954SCole Faust                             ICLTensor *scratch_buffer, ICLTensor *output_state_out, ICLTensor *cell_state_out, ICLTensor *output,
66*c217d954SCole Faust                             const LSTMParams<ICLTensor> &lstm_params, const ActivationLayerInfo &activation_info, float cell_threshold, float projection_threshold)
67*c217d954SCole Faust {
68*c217d954SCole Faust     configure(CLKernelLibrary::get().get_compile_context(), input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights, recurrent_to_forget_weights, recurrent_to_cell_weights,
69*c217d954SCole Faust               recurrent_to_output_weights, forget_gate_bias, cell_bias, output_gate_bias, output_state_in, cell_state_in, scratch_buffer, output_state_out, cell_state_out, output, lstm_params, activation_info,
70*c217d954SCole Faust               cell_threshold, projection_threshold);
71*c217d954SCole Faust }
72*c217d954SCole Faust 
configure(const CLCompileContext & compile_context,const ICLTensor * input,const ICLTensor * input_to_forget_weights,const ICLTensor * input_to_cell_weights,const ICLTensor * input_to_output_weights,const ICLTensor * recurrent_to_forget_weights,const ICLTensor * recurrent_to_cell_weights,const ICLTensor * recurrent_to_output_weights,const ICLTensor * forget_gate_bias,const ICLTensor * cell_bias,const ICLTensor * output_gate_bias,const ICLTensor * output_state_in,ICLTensor * cell_state_in,ICLTensor * scratch_buffer,ICLTensor * output_state_out,ICLTensor * cell_state_out,ICLTensor * output,const LSTMParams<ICLTensor> & lstm_params,const ActivationLayerInfo & activation_info,float cell_threshold,float projection_threshold)73*c217d954SCole Faust void CLLSTMLayer::configure(const CLCompileContext &compile_context, const ICLTensor *input,
74*c217d954SCole Faust                             const ICLTensor *input_to_forget_weights, const ICLTensor *input_to_cell_weights, const ICLTensor *input_to_output_weights,
75*c217d954SCole Faust                             const ICLTensor *recurrent_to_forget_weights, const ICLTensor *recurrent_to_cell_weights, const ICLTensor *recurrent_to_output_weights,
76*c217d954SCole Faust                             const ICLTensor *forget_gate_bias, const ICLTensor *cell_bias, const ICLTensor *output_gate_bias,
77*c217d954SCole Faust                             const ICLTensor *output_state_in, ICLTensor *cell_state_in,
78*c217d954SCole Faust                             ICLTensor *scratch_buffer, ICLTensor *output_state_out, ICLTensor *cell_state_out, ICLTensor *output,
79*c217d954SCole Faust                             const LSTMParams<ICLTensor> &lstm_params, const ActivationLayerInfo &activation_info, float cell_threshold, float projection_threshold)
80*c217d954SCole Faust {
81*c217d954SCole Faust     ARM_COMPUTE_ERROR_ON_NULLPTR(input,
82*c217d954SCole Faust                                  input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
83*c217d954SCole Faust                                  recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
84*c217d954SCole Faust                                  forget_gate_bias, cell_bias, output_gate_bias,
85*c217d954SCole Faust                                  output_state_in, cell_state_in,
86*c217d954SCole Faust                                  scratch_buffer, output_state_out, cell_state_out, output);
87*c217d954SCole Faust 
88*c217d954SCole Faust     ARM_COMPUTE_LOG_PARAMS(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights, recurrent_to_forget_weights, recurrent_to_cell_weights,
89*c217d954SCole Faust                            recurrent_to_output_weights, forget_gate_bias, cell_bias, output_gate_bias, output_state_in, cell_state_in, scratch_buffer, output_state_out, cell_state_out,
90*c217d954SCole Faust                            output, lstm_params, activation_info, cell_threshold, projection_threshold);
91*c217d954SCole Faust 
92*c217d954SCole Faust     _is_layer_norm_lstm = lstm_params.use_layer_norm();
93*c217d954SCole Faust 
94*c217d954SCole Faust     // Set lstm parameters
95*c217d954SCole Faust     LSTMParams<ITensorInfo> lstm_params_info{};
96*c217d954SCole Faust     build_lstm_params_tensor_info(lstm_params, &lstm_params_info);
97*c217d954SCole Faust 
98*c217d954SCole Faust     // Validate
99*c217d954SCole Faust     ARM_COMPUTE_ERROR_THROW_ON(CLLSTMLayer::validate(input->info(), input_to_forget_weights->info(),
100*c217d954SCole Faust                                                      input_to_cell_weights->info(), input_to_output_weights->info(),
101*c217d954SCole Faust                                                      recurrent_to_forget_weights->info(), recurrent_to_cell_weights->info(), recurrent_to_output_weights->info(),
102*c217d954SCole Faust                                                      forget_gate_bias->info(), cell_bias->info(), output_gate_bias->info(),
103*c217d954SCole Faust                                                      output_state_in->info(), cell_state_in->info(),
104*c217d954SCole Faust                                                      scratch_buffer->info(), output_state_out->info(), cell_state_out->info(), output->info(),
105*c217d954SCole Faust                                                      lstm_params_info, activation_info, cell_threshold, projection_threshold));
106*c217d954SCole Faust 
107*c217d954SCole Faust     const TensorShape cell_state_shape = cell_state_in->info()->tensor_shape();
108*c217d954SCole Faust     // Configure block that calculates the forget gate
109*c217d954SCole Faust     // forget_gate = Activation(input * input_to_forget_weights + output_state_in * recurrent_to_forget_weights + PixelWiseMul(cell_state, cell_to_forget_weights) + forget_gate_bias)
110*c217d954SCole Faust     // We optimize this as follows:
111*c217d954SCole Faust     // forget_gate = Activation( (input,output_state_in) * (input_to_forget_weights,recurrent_to_forget_weights) + PixelWiseMul(cell_state, cell_to_forget_weights) + forget_gate_bias
112*c217d954SCole Faust     _forget_gate_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
113*c217d954SCole Faust     _forget_gate_out3.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
114*c217d954SCole Faust     _forget_gate_out5.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
115*c217d954SCole Faust 
116*c217d954SCole Faust     std::vector<const ICLTensor *> inputs_vector;
117*c217d954SCole Faust     inputs_vector.emplace_back(input);
118*c217d954SCole Faust     inputs_vector.emplace_back(output_state_in);
119*c217d954SCole Faust     const TensorShape concat_shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(inputs_vector, 0);
120*c217d954SCole Faust     _forget_gate_out2.allocator()->init(TensorInfo(concat_shape, 1, input->info()->data_type()));
121*c217d954SCole Faust 
122*c217d954SCole Faust     _memory_group.manage(&_forget_gate_out2);
123*c217d954SCole Faust     _concat_inputs_forget_gate.configure(compile_context, inputs_vector, &_forget_gate_out2, Window::DimX);
124*c217d954SCole Faust 
125*c217d954SCole Faust     std::vector<const ICLTensor *> weights_vector;
126*c217d954SCole Faust 
127*c217d954SCole Faust     weights_vector.emplace_back(input_to_forget_weights);
128*c217d954SCole Faust     weights_vector.emplace_back(recurrent_to_forget_weights);
129*c217d954SCole Faust     const TensorShape weights_concat_shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(weights_vector, 0);
130*c217d954SCole Faust     _forget_gate_out6.allocator()->init(TensorInfo(weights_concat_shape, 1, input->info()->data_type()));
131*c217d954SCole Faust 
132*c217d954SCole Faust     _concat_weights_forget_gate.configure(compile_context, weights_vector, &_forget_gate_out6, Window::DimX);
133*c217d954SCole Faust 
134*c217d954SCole Faust     _memory_group.manage(&_forget_gate_out5);
135*c217d954SCole Faust     _fully_connected_forget_gate.configure(compile_context, &_forget_gate_out2, &_forget_gate_out6, (_is_layer_norm_lstm) ? nullptr : forget_gate_bias, &_forget_gate_out5);
136*c217d954SCole Faust     _memory_group.manage(&_forget_gate_out1);
137*c217d954SCole Faust     _memory_group.manage(&_forget_gate_out3);
138*c217d954SCole Faust     _forget_gate_out6.allocator()->allocate();
139*c217d954SCole Faust 
140*c217d954SCole Faust     CLTensor *forget_gate_out = &_forget_gate_out5;
141*c217d954SCole Faust     if(lstm_params.has_peephole_opt())
142*c217d954SCole Faust     {
143*c217d954SCole Faust         _forget_gate_out4.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
144*c217d954SCole Faust 
145*c217d954SCole Faust         _run_peephole_opt = true;
146*c217d954SCole Faust         _memory_group.manage(&_forget_gate_out4);
147*c217d954SCole Faust         _pixelwise_mul_forget_gate.configure(compile_context, cell_state_in, lstm_params.cell_to_forget_weights(), &_forget_gate_out4, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
148*c217d954SCole Faust         _accum_forget_gate1.configure(compile_context, &_forget_gate_out5, &_forget_gate_out4, &_forget_gate_out3, ConvertPolicy::SATURATE);
149*c217d954SCole Faust         _forget_gate_out4.allocator()->allocate();
150*c217d954SCole Faust         _forget_gate_out5.allocator()->allocate();
151*c217d954SCole Faust         forget_gate_out = &_forget_gate_out3;
152*c217d954SCole Faust     }
153*c217d954SCole Faust     else
154*c217d954SCole Faust     {
155*c217d954SCole Faust         _forget_gate_out3.allocator()->allocate();
156*c217d954SCole Faust     }
157*c217d954SCole Faust     if(_is_layer_norm_lstm)
158*c217d954SCole Faust     {
159*c217d954SCole Faust         _forget_layer_norm_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
160*c217d954SCole Faust         _forget_layer_norm_out2.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
161*c217d954SCole Faust         _memory_group.manage(&_forget_layer_norm_out1);
162*c217d954SCole Faust         _memory_group.manage(&_forget_layer_norm_out2);
163*c217d954SCole Faust         _mean_std_norm_forget_gate.configure(compile_context, forget_gate_out);
164*c217d954SCole Faust         _pixelwise_mul_forget_gate_coeff.configure(compile_context, forget_gate_out, lstm_params.forget_layer_norm_weights(), &_forget_layer_norm_out1, 1, ConvertPolicy::SATURATE,
165*c217d954SCole Faust                                                    RoundingPolicy::TO_NEAREST_EVEN);
166*c217d954SCole Faust         // forget_gate_out is going to be reassigned, so allocate the tensor that it was assigned to before
167*c217d954SCole Faust         forget_gate_out->allocator()->allocate();
168*c217d954SCole Faust         _accum_forget_gate_bias.configure(compile_context, &_forget_layer_norm_out1, forget_gate_bias, &_forget_layer_norm_out2, ConvertPolicy::SATURATE);
169*c217d954SCole Faust         _forget_layer_norm_out1.allocator()->allocate();
170*c217d954SCole Faust         forget_gate_out = &_forget_layer_norm_out2;
171*c217d954SCole Faust     }
172*c217d954SCole Faust     _activation_forget_gate.configure(compile_context, forget_gate_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
173*c217d954SCole Faust 
174*c217d954SCole Faust     // Configure block that calculates the input gate
175*c217d954SCole Faust     // input_gate = Activation(input * input_to_input_weights + output_state * recurrent_to_input_weights + PixelWiseMul(cell_state, cell_to_input_weights) + input_gate_bias), without CIFG
176*c217d954SCole Faust     // input_gate = 1 - forget_gate, with CIFG
177*c217d954SCole Faust     // We optimize this as follows:
178*c217d954SCole Faust     // input_gate = Activation((input,output_state) * (input_to_input_weights,recurrent_to_input_weights) + PixelWiseMul(cell_state, cell_to_input_weights) + input_gate_bias), without CIFG
179*c217d954SCole Faust     _input_gate_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
180*c217d954SCole Faust     CLTensor *input_gate_out = &_input_gate_out1;
181*c217d954SCole Faust     if(lstm_params.has_cifg_opt())
182*c217d954SCole Faust     {
183*c217d954SCole Faust         _memory_group.manage(&_input_gate_out1);
184*c217d954SCole Faust         _ones.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
185*c217d954SCole Faust         _ones_fill.configure(compile_context, &_ones, PixelValue(1, _ones.info()->data_type()));
186*c217d954SCole Faust         _subtract_input_gate.configure(compile_context, &_ones, forget_gate_out, &_input_gate_out1, ConvertPolicy::SATURATE);
187*c217d954SCole Faust         _ones.allocator()->allocate();
188*c217d954SCole Faust         _run_cifg_opt = true;
189*c217d954SCole Faust     }
190*c217d954SCole Faust     else
191*c217d954SCole Faust     {
192*c217d954SCole Faust         _input_gate_out3.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
193*c217d954SCole Faust         _input_gate_out4.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
194*c217d954SCole Faust 
195*c217d954SCole Faust         std::vector<const ICLTensor *> lstm_weights;
196*c217d954SCole Faust         lstm_weights.emplace_back(lstm_params.input_to_input_weights());
197*c217d954SCole Faust         lstm_weights.emplace_back(lstm_params.recurrent_to_input_weights());
198*c217d954SCole Faust         TensorShape lstm_weights_concat_shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(lstm_weights, 0);
199*c217d954SCole Faust         _input_gate_out2.allocator()->init(TensorInfo(lstm_weights_concat_shape, 1, input->info()->data_type()));
200*c217d954SCole Faust 
201*c217d954SCole Faust         _concat_weights_input_gate.configure(compile_context, lstm_weights, &_input_gate_out2, Window::DimX);
202*c217d954SCole Faust 
203*c217d954SCole Faust         _memory_group.manage(&_input_gate_out1);
204*c217d954SCole Faust 
205*c217d954SCole Faust         _memory_group.manage(&_input_gate_out3);
206*c217d954SCole Faust         _fully_connected_input_gate.configure(compile_context, &_forget_gate_out2, &_input_gate_out2, (_is_layer_norm_lstm) ? nullptr : lstm_params.input_gate_bias(), &_input_gate_out3);
207*c217d954SCole Faust         _input_gate_out2.allocator()->allocate();
208*c217d954SCole Faust 
209*c217d954SCole Faust         input_gate_out = &_input_gate_out3;
210*c217d954SCole Faust         if(_run_peephole_opt)
211*c217d954SCole Faust         {
212*c217d954SCole Faust             _memory_group.manage(&_input_gate_out4);
213*c217d954SCole Faust             _pixelwise_mul_input_gate.configure(compile_context, cell_state_in, lstm_params.cell_to_input_weights(), &_input_gate_out4, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
214*c217d954SCole Faust             _accum_input_gate1.configure(compile_context, &_input_gate_out3, &_input_gate_out4, &_input_gate_out1, ConvertPolicy::SATURATE);
215*c217d954SCole Faust             _input_gate_out3.allocator()->allocate();
216*c217d954SCole Faust             _input_gate_out4.allocator()->allocate();
217*c217d954SCole Faust             input_gate_out = &_input_gate_out1;
218*c217d954SCole Faust         }
219*c217d954SCole Faust         else
220*c217d954SCole Faust         {
221*c217d954SCole Faust             _input_gate_out1.allocator()->allocate();
222*c217d954SCole Faust         }
223*c217d954SCole Faust 
224*c217d954SCole Faust         if(_is_layer_norm_lstm)
225*c217d954SCole Faust         {
226*c217d954SCole Faust             _input_layer_norm_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
227*c217d954SCole Faust             _input_layer_norm_out2.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
228*c217d954SCole Faust             _memory_group.manage(&_input_layer_norm_out1);
229*c217d954SCole Faust             _memory_group.manage(&_input_layer_norm_out2);
230*c217d954SCole Faust             _mean_std_norm_input_gate.configure(compile_context, input_gate_out);
231*c217d954SCole Faust             _pixelwise_mul_input_gate_coeff.configure(compile_context, input_gate_out, lstm_params.input_layer_norm_weights(), &_input_layer_norm_out1, 1, ConvertPolicy::SATURATE,
232*c217d954SCole Faust                                                       RoundingPolicy::TO_NEAREST_EVEN);
233*c217d954SCole Faust             // input_gate_out is going to be reassigned, so allocate the tensor that it was assigned to before
234*c217d954SCole Faust             input_gate_out->allocator()->allocate();
235*c217d954SCole Faust             _accum_input_gate_bias.configure(compile_context, &_input_layer_norm_out1, lstm_params.input_gate_bias(), &_input_layer_norm_out2, ConvertPolicy::SATURATE);
236*c217d954SCole Faust             _input_layer_norm_out1.allocator()->allocate();
237*c217d954SCole Faust             input_gate_out = &_input_layer_norm_out2;
238*c217d954SCole Faust         }
239*c217d954SCole Faust         _activation_input_gate.configure(compile_context, input_gate_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
240*c217d954SCole Faust     }
241*c217d954SCole Faust 
242*c217d954SCole Faust     // Configure block that calculates the cell state
243*c217d954SCole Faust     // cell_state = Clip((PixelwiseMul(input_gate, Activation(input * input_to_cell_weights + output_state_in * recurrent_to_cell_weights + cell_bias)) + PixelwiseMul(forget_gate, cell_state)), cell_threshold)
244*c217d954SCole Faust     TensorShape cell_state1_shape = compute_transposed_shape(*recurrent_to_output_weights->info());
245*c217d954SCole Faust     _cell_state_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
246*c217d954SCole Faust     _cell_state_out2.allocator()->init(TensorInfo(cell_state1_shape, 1, input->info()->data_type()));
247*c217d954SCole Faust     _cell_state_out3.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
248*c217d954SCole Faust     _cell_state_out4.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
249*c217d954SCole Faust     _cell_state_out5.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
250*c217d954SCole Faust 
251*c217d954SCole Faust     _memory_group.manage(&_cell_state_out1);
252*c217d954SCole Faust     _fully_connected_cell_state.configure(compile_context, input, input_to_cell_weights, (_is_layer_norm_lstm) ? nullptr : cell_bias, &_cell_state_out1);
253*c217d954SCole Faust     _memory_group.manage(&_cell_state_out2);
254*c217d954SCole Faust     _transpose_cell_state->configure(compile_context, recurrent_to_cell_weights->info(), _cell_state_out2.info());
255*c217d954SCole Faust     _recurrent_to_cell_weights = recurrent_to_cell_weights;
256*c217d954SCole Faust     _memory_group.manage(&_cell_state_out3);
257*c217d954SCole Faust     _gemm_cell_state1.configure(compile_context, output_state_in, &_cell_state_out2, nullptr, &_cell_state_out3, 1.f, 0.f);
258*c217d954SCole Faust     _cell_state_out2.allocator()->allocate();
259*c217d954SCole Faust     _memory_group.manage(&_cell_state_out4);
260*c217d954SCole Faust     _accum_cell_state1.configure(compile_context, &_cell_state_out1, &_cell_state_out3, &_cell_state_out4, ConvertPolicy::SATURATE);
261*c217d954SCole Faust     CLTensor *cell_state_out_ptr = &_cell_state_out4;
262*c217d954SCole Faust     if(_is_layer_norm_lstm)
263*c217d954SCole Faust     {
264*c217d954SCole Faust         _cell_layer_norm_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
265*c217d954SCole Faust         _cell_layer_norm_out2.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
266*c217d954SCole Faust         _memory_group.manage(&_cell_layer_norm_out1);
267*c217d954SCole Faust         _memory_group.manage(&_cell_layer_norm_out2);
268*c217d954SCole Faust         _mean_std_norm_cell_gate.configure(compile_context, cell_state_out_ptr);
269*c217d954SCole Faust         _pixelwise_mul_cell_gate_coeff.configure(compile_context, cell_state_out_ptr, lstm_params.cell_layer_norm_weights(), &_cell_layer_norm_out1, 1, ConvertPolicy::SATURATE,
270*c217d954SCole Faust                                                  RoundingPolicy::TO_NEAREST_EVEN);
271*c217d954SCole Faust         // cell_state_out_ptr is going to be reassigned, so allocate the tensor that it was assigned to before
272*c217d954SCole Faust         cell_state_out_ptr->allocator()->allocate();
273*c217d954SCole Faust         _accum_cell_gate_bias.configure(compile_context, &_cell_layer_norm_out1, cell_bias, &_cell_layer_norm_out2, ConvertPolicy::SATURATE);
274*c217d954SCole Faust         _cell_layer_norm_out1.allocator()->allocate();
275*c217d954SCole Faust         cell_state_out_ptr = &_cell_layer_norm_out2;
276*c217d954SCole Faust     }
277*c217d954SCole Faust     _activation_cell_state.configure(compile_context, cell_state_out_ptr, nullptr, activation_info);
278*c217d954SCole Faust     _memory_group.manage(&_cell_state_out5);
279*c217d954SCole Faust     _pixelwise_mul_cell_state1.configure(compile_context, cell_state_out_ptr, input_gate_out, &_cell_state_out5, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
280*c217d954SCole Faust     cell_state_out_ptr->allocator()->allocate();
281*c217d954SCole Faust     _pixelwise_mul_cell_state2.configure(compile_context, forget_gate_out, cell_state_in, &_cell_state_out3, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
282*c217d954SCole Faust     _accum_cell_state2.configure(compile_context, &_cell_state_out5, &_cell_state_out3, &_cell_state_out1, ConvertPolicy::SATURATE);
283*c217d954SCole Faust     _cell_state_out3.allocator()->allocate();
284*c217d954SCole Faust     _cell_state_out5.allocator()->allocate();
285*c217d954SCole Faust     // Perform clipping
286*c217d954SCole Faust     if(cell_threshold != 0.f)
287*c217d954SCole Faust     {
288*c217d954SCole Faust         _perform_cell_clipping = true;
289*c217d954SCole Faust         _cell_clip.configure(compile_context, &_cell_state_out1, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, cell_threshold, -cell_threshold));
290*c217d954SCole Faust     }
291*c217d954SCole Faust 
292*c217d954SCole Faust     // Configure block that calculates the output
293*c217d954SCole Faust     // output_state_out = Activation(input * input_to_output_weights + output_state_in * recurrent_to_output_weights + PixelWiseMul(cell_state, cell_to_output_weights) + output_gate_bias)
294*c217d954SCole Faust     // We optimize this as follows:
295*c217d954SCole Faust     // output_state_out = Activation( (input,output_state_in) * (input_to_output_weights, recurrent_to_output_weights) + PixelWiseMul(cell_state, cell_to_output_weights) + output_gate_bias)
296*c217d954SCole Faust     _output1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
297*c217d954SCole Faust     _output4.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
298*c217d954SCole Faust     std::vector<const ICLTensor *> in_out_weights;
299*c217d954SCole Faust     in_out_weights.emplace_back(input_to_output_weights);
300*c217d954SCole Faust     in_out_weights.emplace_back(recurrent_to_output_weights);
301*c217d954SCole Faust     TensorShape in_out_weights_concat_shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(in_out_weights, 0);
302*c217d954SCole Faust     _output2.allocator()->init(TensorInfo(in_out_weights_concat_shape, 1, input->info()->data_type()));
303*c217d954SCole Faust 
304*c217d954SCole Faust     _concat_weights_output.configure(compile_context, in_out_weights, &_output2, Window::DimX);
305*c217d954SCole Faust 
306*c217d954SCole Faust     _memory_group.manage(&_output1);
307*c217d954SCole Faust     _memory_group.manage(&_output4);
308*c217d954SCole Faust 
309*c217d954SCole Faust     _fully_connected_output.configure(compile_context, &_forget_gate_out2, &_output2, (_is_layer_norm_lstm) ? nullptr : output_gate_bias, &_output4);
310*c217d954SCole Faust 
311*c217d954SCole Faust     _output2.allocator()->allocate();
312*c217d954SCole Faust     _forget_gate_out2.allocator()->allocate();
313*c217d954SCole Faust 
314*c217d954SCole Faust     CLTensor *output_gate_out = &_output4;
315*c217d954SCole Faust     if(lstm_params.has_peephole_opt())
316*c217d954SCole Faust     {
317*c217d954SCole Faust         _output3.allocator()->init(TensorInfo(_cell_state_out1.info()->tensor_shape(), 1, input->info()->data_type()));
318*c217d954SCole Faust 
319*c217d954SCole Faust         _memory_group.manage(&_output3);
320*c217d954SCole Faust         _pixelwise_mul_output_state1.configure(compile_context, &_cell_state_out1, lstm_params.cell_to_output_weights(), &_output3, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
321*c217d954SCole Faust         _accum_output1.configure(compile_context, &_output4, &_output3, &_output1, ConvertPolicy::SATURATE);
322*c217d954SCole Faust         _output4.allocator()->allocate();
323*c217d954SCole Faust         output_gate_out = &_output1;
324*c217d954SCole Faust 
325*c217d954SCole Faust         // Allocate intermediate buffers
326*c217d954SCole Faust         _output3.allocator()->allocate();
327*c217d954SCole Faust     }
328*c217d954SCole Faust     else
329*c217d954SCole Faust     {
330*c217d954SCole Faust         _output1.allocator()->allocate();
331*c217d954SCole Faust     }
332*c217d954SCole Faust     if(_is_layer_norm_lstm)
333*c217d954SCole Faust     {
334*c217d954SCole Faust         _output_layer_norm_out1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
335*c217d954SCole Faust         _output_layer_norm_out2.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
336*c217d954SCole Faust         _memory_group.manage(&_output_layer_norm_out1);
337*c217d954SCole Faust         _memory_group.manage(&_output_layer_norm_out2);
338*c217d954SCole Faust         _mean_std_norm_output_gate.configure(compile_context, output_gate_out);
339*c217d954SCole Faust         _pixelwise_mul_output_gate_coeff.configure(compile_context, output_gate_out, lstm_params.output_layer_norm_weights(), &_output_layer_norm_out1, 1, ConvertPolicy::SATURATE,
340*c217d954SCole Faust                                                    RoundingPolicy::TO_NEAREST_EVEN);
341*c217d954SCole Faust         // output_gate_out is going to be reassigned, so allocate the tensor that it was assigned to before
342*c217d954SCole Faust         output_gate_out->allocator()->allocate();
343*c217d954SCole Faust         _accum_output_gate_bias.configure(compile_context, &_output_layer_norm_out1, output_gate_bias, &_output_layer_norm_out2, ConvertPolicy::SATURATE);
344*c217d954SCole Faust         _output_layer_norm_out1.allocator()->allocate();
345*c217d954SCole Faust         output_gate_out = &_output_layer_norm_out2;
346*c217d954SCole Faust     }
347*c217d954SCole Faust     _activation_output.configure(compile_context, output_gate_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
348*c217d954SCole Faust 
349*c217d954SCole Faust     // Configure block that calculates the output state
350*c217d954SCole Faust     /** lstm_res = PixelwiseMul(output, Activation(cell_state))
351*c217d954SCole Faust      *
352*c217d954SCole Faust      *                      -- Clip(lstm_res * projection_weights + projection_bias, projection_threshold) , if there is a projection
353*c217d954SCole Faust      *                     /
354*c217d954SCole Faust      *  output_state =  --
355*c217d954SCole Faust      *                     \
356*c217d954SCole Faust      *                      -- lstm_res , otherwise
357*c217d954SCole Faust      */
358*c217d954SCole Faust     ICLTensor *output_state_out_tmp = lstm_params.has_projection() ? &_output_state1 : output_state_out;
359*c217d954SCole Faust     _cell_state_activation.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
360*c217d954SCole Faust     _output_state1.allocator()->init(TensorInfo(cell_state_shape, 1, input->info()->data_type()));
361*c217d954SCole Faust 
362*c217d954SCole Faust     _memory_group.manage(&_cell_state_activation);
363*c217d954SCole Faust     _activation_output_state.configure(compile_context, &_cell_state_out1, &_cell_state_activation, activation_info);
364*c217d954SCole Faust     _pixelwise_mul_output_state2.configure(compile_context, &_cell_state_activation, output_gate_out, output_state_out_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
365*c217d954SCole Faust     _cell_state_activation.allocator()->allocate();
366*c217d954SCole Faust 
367*c217d954SCole Faust     if(lstm_params.has_projection())
368*c217d954SCole Faust     {
369*c217d954SCole Faust         _has_projection_weights = true;
370*c217d954SCole Faust         _fully_connected_output_state.configure(compile_context, output_state_out_tmp, lstm_params.projection_weights(), lstm_params.projection_bias(), output_state_out);
371*c217d954SCole Faust         _output_state1.allocator()->allocate();
372*c217d954SCole Faust         // Perform clipping
373*c217d954SCole Faust         if(projection_threshold != 0.f)
374*c217d954SCole Faust         {
375*c217d954SCole Faust             _perform_projection_clipping = true;
376*c217d954SCole Faust             _projection_clip.configure(compile_context, output_state_out, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -projection_threshold, projection_threshold));
377*c217d954SCole Faust         }
378*c217d954SCole Faust     }
379*c217d954SCole Faust 
380*c217d954SCole Faust     // Copy cell state and output
381*c217d954SCole Faust     _copy_cell_state.configure(compile_context, &_cell_state_out1, cell_state_out);
382*c217d954SCole Faust     _copy_output.configure(compile_context, output_state_out, output);
383*c217d954SCole Faust 
384*c217d954SCole Faust     // Vector for holding the tensors to store in scratch buffer
385*c217d954SCole Faust     std::vector<const ICLTensor *> scratch_inputs;
386*c217d954SCole Faust     if(!lstm_params.has_cifg_opt())
387*c217d954SCole Faust     {
388*c217d954SCole Faust         scratch_inputs.emplace_back(input_gate_out);
389*c217d954SCole Faust     }
390*c217d954SCole Faust     scratch_inputs.emplace_back(&_cell_state_out1);
391*c217d954SCole Faust     scratch_inputs.emplace_back(forget_gate_out);
392*c217d954SCole Faust     scratch_inputs.emplace_back(output_gate_out);
393*c217d954SCole Faust     _concat_scratch_buffer.configure(compile_context, scratch_inputs, scratch_buffer, Window::DimX);
394*c217d954SCole Faust     input_gate_out->allocator()->allocate();
395*c217d954SCole Faust     _cell_state_out1.allocator()->allocate();
396*c217d954SCole Faust     forget_gate_out->allocator()->allocate();
397*c217d954SCole Faust     output_gate_out->allocator()->allocate();
398*c217d954SCole Faust }
399*c217d954SCole Faust 
validate(const ITensorInfo * input,const ITensorInfo * input_to_forget_weights,const ITensorInfo * input_to_cell_weights,const ITensorInfo * input_to_output_weights,const ITensorInfo * recurrent_to_forget_weights,const ITensorInfo * recurrent_to_cell_weights,const ITensorInfo * recurrent_to_output_weights,const ITensorInfo * forget_gate_bias,const ITensorInfo * cell_bias,const ITensorInfo * output_gate_bias,const ITensorInfo * output_state_in,const ITensorInfo * cell_state_in,const ITensorInfo * scratch_buffer,const ITensorInfo * output_state_out,const ITensorInfo * cell_state_out,const ITensorInfo * output,const LSTMParams<ITensorInfo> & lstm_params,const ActivationLayerInfo & activation_info,float cell_threshold,float projection_threshold)400*c217d954SCole Faust Status CLLSTMLayer::validate(const ITensorInfo *input,
401*c217d954SCole Faust                              const ITensorInfo *input_to_forget_weights, const ITensorInfo *input_to_cell_weights, const ITensorInfo *input_to_output_weights,
402*c217d954SCole Faust                              const ITensorInfo *recurrent_to_forget_weights, const ITensorInfo *recurrent_to_cell_weights, const ITensorInfo *recurrent_to_output_weights,
403*c217d954SCole Faust                              const ITensorInfo *forget_gate_bias, const ITensorInfo *cell_bias, const ITensorInfo *output_gate_bias,
404*c217d954SCole Faust                              const ITensorInfo *output_state_in, const ITensorInfo *cell_state_in,
405*c217d954SCole Faust                              const ITensorInfo *scratch_buffer, const ITensorInfo *output_state_out, const ITensorInfo *cell_state_out, const ITensorInfo *output,
406*c217d954SCole Faust                              const LSTMParams<ITensorInfo> &lstm_params, const ActivationLayerInfo &activation_info, float cell_threshold, float projection_threshold)
407*c217d954SCole Faust {
408*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input,
409*c217d954SCole Faust                                         input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
410*c217d954SCole Faust                                         recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
411*c217d954SCole Faust                                         forget_gate_bias, cell_bias, output_gate_bias,
412*c217d954SCole Faust                                         output_state_in, cell_state_in,
413*c217d954SCole Faust                                         scratch_buffer, output_state_out, cell_state_out, output);
414*c217d954SCole Faust 
415*c217d954SCole Faust     // Check data types
416*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
417*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input,
418*c217d954SCole Faust                                                        input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
419*c217d954SCole Faust                                                        recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
420*c217d954SCole Faust                                                        forget_gate_bias, cell_bias, output_gate_bias,
421*c217d954SCole Faust                                                        output_state_in, cell_state_in,
422*c217d954SCole Faust                                                        scratch_buffer, output_state_out, cell_state_out, output);
423*c217d954SCole Faust 
424*c217d954SCole Faust     // Check dimensions
425*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() > 2);
426*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON(input_to_forget_weights->num_dimensions() > 2);
427*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON(input_to_cell_weights->num_dimensions() > 2);
428*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON(input_to_output_weights->num_dimensions() > 2);
429*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_forget_weights->num_dimensions() > 2);
430*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_cell_weights->num_dimensions() > 2);
431*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_output_weights->num_dimensions() > 2);
432*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON(forget_gate_bias->num_dimensions() > 1);
433*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON(cell_bias->num_dimensions() > 1);
434*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON(output_gate_bias->num_dimensions() > 1);
435*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON(output_state_in->num_dimensions() > 2);
436*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON(cell_state_in->num_dimensions() > 2);
437*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON(scratch_buffer->num_dimensions() > 2);
438*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON(output_state_out->num_dimensions() > 2);
439*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON(cell_state_out->num_dimensions() > 2);
440*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON(output->num_dimensions() > 2);
441*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON(cell_bias->dimension(0) * 4 != scratch_buffer->dimension(0)
442*c217d954SCole Faust                                 && cell_bias->dimension(0) * 3 != scratch_buffer->dimension(0));
443*c217d954SCole Faust 
444*c217d954SCole Faust     const unsigned int num_batches = input->dimension(1);
445*c217d954SCole Faust     const unsigned int num_cells   = input_to_output_weights->dimension(1);
446*c217d954SCole Faust 
447*c217d954SCole Faust     if(lstm_params.use_layer_norm())
448*c217d954SCole Faust     {
449*c217d954SCole Faust         // If CIFG is used, input layer normalization weights tensor is omitted
450*c217d954SCole Faust         if(lstm_params.has_cifg_opt())
451*c217d954SCole Faust         {
452*c217d954SCole Faust             ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_layer_norm_weights() != nullptr);
453*c217d954SCole Faust         }
454*c217d954SCole Faust         else
455*c217d954SCole Faust         {
456*c217d954SCole Faust             ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.input_layer_norm_weights());
457*c217d954SCole Faust             ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_layer_norm_weights()->num_dimensions() > 1);
458*c217d954SCole Faust             ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_layer_norm_weights()->dimension(0) != num_cells);
459*c217d954SCole Faust             ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, lstm_params.input_layer_norm_weights());
460*c217d954SCole Faust         }
461*c217d954SCole Faust 
462*c217d954SCole Faust         ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.forget_layer_norm_weights(), lstm_params.cell_layer_norm_weights(), lstm_params.output_layer_norm_weights());
463*c217d954SCole Faust         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, lstm_params.forget_layer_norm_weights(), lstm_params.cell_layer_norm_weights(), lstm_params.output_layer_norm_weights());
464*c217d954SCole Faust         ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.forget_layer_norm_weights()->num_dimensions() > 1);
465*c217d954SCole Faust         ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_layer_norm_weights()->num_dimensions() > 1);
466*c217d954SCole Faust         ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.output_layer_norm_weights()->num_dimensions() > 1);
467*c217d954SCole Faust         ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.forget_layer_norm_weights()->dimension(0) != num_cells);
468*c217d954SCole Faust         ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_layer_norm_weights()->dimension(0) != num_cells);
469*c217d954SCole Faust         ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.output_layer_norm_weights()->dimension(0) != num_cells);
470*c217d954SCole Faust     }
471*c217d954SCole Faust 
472*c217d954SCole Faust     // Check peephole optimization
473*c217d954SCole Faust     if(lstm_params.has_peephole_opt())
474*c217d954SCole Faust     {
475*c217d954SCole Faust         ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_output_weights(), lstm_params.cell_to_forget_weights());
476*c217d954SCole Faust         ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_forget_weights()->num_dimensions() > 1);
477*c217d954SCole Faust         ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_output_weights()->num_dimensions() > 1);
478*c217d954SCole Faust     }
479*c217d954SCole Faust 
480*c217d954SCole Faust     TensorShape      units_out_transposed_shape = compute_transposed_shape(*recurrent_to_output_weights);
481*c217d954SCole Faust     TensorShape      num_units_transposed_shape = compute_transposed_shape(*forget_gate_bias);
482*c217d954SCole Faust     const TensorInfo units_out_transposed_info  = TensorInfo(units_out_transposed_shape, 1, input->data_type());
483*c217d954SCole Faust     const TensorInfo num_units_transposed_info  = TensorInfo(num_units_transposed_shape, 1, input->data_type());
484*c217d954SCole Faust 
485*c217d954SCole Faust     TensorInfo input_gate      = TensorInfo(TensorShape(num_cells, num_batches), 1, input->data_type());
486*c217d954SCole Faust     TensorInfo forget_gate     = TensorInfo(TensorShape(num_cells, num_batches), 1, input->data_type());
487*c217d954SCole Faust     TensorInfo output_gate_tmp = TensorInfo(TensorShape(num_cells, num_batches), 1, input->data_type());
488*c217d954SCole Faust     TensorInfo cell_state_tmp  = TensorInfo(TensorShape(num_cells, num_batches), 1, input->data_type());
489*c217d954SCole Faust 
490*c217d954SCole Faust     // Validate forget gate
491*c217d954SCole Faust     ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, input_to_forget_weights, (lstm_params.use_layer_norm()) ? nullptr : forget_gate_bias, &forget_gate));
492*c217d954SCole Faust 
493*c217d954SCole Faust     std::vector<const ITensorInfo *> inputs_vector;
494*c217d954SCole Faust     inputs_vector.emplace_back(input);
495*c217d954SCole Faust     inputs_vector.emplace_back(output_state_in);
496*c217d954SCole Faust     const TensorShape concat_shape       = arm_compute::misc::shape_calculator::calculate_concatenate_shape(inputs_vector, 0);
497*c217d954SCole Faust     TensorInfo        forget_gate_concat = TensorInfo(concat_shape, 1, input->data_type());
498*c217d954SCole Faust 
499*c217d954SCole Faust     ARM_COMPUTE_RETURN_ON_ERROR(CLConcatenateLayer::validate(inputs_vector, &forget_gate_concat, Window::DimX));
500*c217d954SCole Faust 
501*c217d954SCole Faust     if(lstm_params.has_peephole_opt())
502*c217d954SCole Faust     {
503*c217d954SCole Faust         ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(cell_state_in, lstm_params.cell_to_forget_weights(), &forget_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
504*c217d954SCole Faust         ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&forget_gate, &forget_gate, &forget_gate, ConvertPolicy::SATURATE));
505*c217d954SCole Faust     }
506*c217d954SCole Faust     if(lstm_params.use_layer_norm())
507*c217d954SCole Faust     {
508*c217d954SCole Faust         ARM_COMPUTE_RETURN_ON_ERROR(CLMeanStdDevNormalizationLayer::validate(&forget_gate));
509*c217d954SCole Faust         ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&forget_gate, lstm_params.forget_layer_norm_weights(), &forget_gate, 1, ConvertPolicy::SATURATE,
510*c217d954SCole Faust                                                                         RoundingPolicy::TO_NEAREST_EVEN));
511*c217d954SCole Faust         ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&forget_gate, forget_gate_bias, &forget_gate, ConvertPolicy::SATURATE));
512*c217d954SCole Faust     }
513*c217d954SCole Faust     ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&forget_gate, &forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
514*c217d954SCole Faust 
515*c217d954SCole Faust     // Validate input gate
516*c217d954SCole Faust     if(!lstm_params.has_cifg_opt())
517*c217d954SCole Faust     {
518*c217d954SCole Faust         ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.input_to_input_weights(),
519*c217d954SCole Faust                                             lstm_params.recurrent_to_input_weights(),
520*c217d954SCole Faust                                             lstm_params.input_gate_bias());
521*c217d954SCole Faust         ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_to_input_weights()->num_dimensions() > 2);
522*c217d954SCole Faust         ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.recurrent_to_input_weights()->num_dimensions() > 2);
523*c217d954SCole Faust         ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_gate_bias()->num_dimensions() > 1);
524*c217d954SCole Faust 
525*c217d954SCole Faust         std::vector<const ITensorInfo *> lstm_weights;
526*c217d954SCole Faust         lstm_weights.emplace_back(lstm_params.input_to_input_weights());
527*c217d954SCole Faust         lstm_weights.emplace_back(lstm_params.recurrent_to_input_weights());
528*c217d954SCole Faust         TensorShape lstm_weights_concat_shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(lstm_weights, 0);
529*c217d954SCole Faust         TensorInfo  lstm_gate_concat          = TensorInfo(lstm_weights_concat_shape, 1, input->data_type());
530*c217d954SCole Faust         ARM_COMPUTE_RETURN_ON_ERROR(CLConcatenateLayer::validate(lstm_weights, &lstm_gate_concat, Window::DimX));
531*c217d954SCole Faust 
532*c217d954SCole Faust         ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, lstm_params.input_to_input_weights(), (lstm_params.use_layer_norm()) ? nullptr : lstm_params.input_gate_bias(), &input_gate));
533*c217d954SCole Faust 
534*c217d954SCole Faust         if(lstm_params.has_peephole_opt())
535*c217d954SCole Faust         {
536*c217d954SCole Faust             ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_input_weights());
537*c217d954SCole Faust             ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_input_weights()->num_dimensions() > 1);
538*c217d954SCole Faust             ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(cell_state_in, lstm_params.cell_to_input_weights(), &input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
539*c217d954SCole Faust             ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&input_gate, &input_gate, &input_gate, ConvertPolicy::SATURATE));
540*c217d954SCole Faust         }
541*c217d954SCole Faust 
542*c217d954SCole Faust         if(lstm_params.use_layer_norm())
543*c217d954SCole Faust         {
544*c217d954SCole Faust             ARM_COMPUTE_RETURN_ON_ERROR(CLMeanStdDevNormalizationLayer::validate(&input_gate));
545*c217d954SCole Faust             ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&input_gate, lstm_params.input_layer_norm_weights(), &input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
546*c217d954SCole Faust             ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&input_gate, lstm_params.input_gate_bias(), &input_gate, ConvertPolicy::SATURATE));
547*c217d954SCole Faust         }
548*c217d954SCole Faust         ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&input_gate, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
549*c217d954SCole Faust     }
550*c217d954SCole Faust     else
551*c217d954SCole Faust     {
552*c217d954SCole Faust         ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticSubtraction::validate(&forget_gate, &forget_gate, &forget_gate, ConvertPolicy::SATURATE));
553*c217d954SCole Faust     }
554*c217d954SCole Faust 
555*c217d954SCole Faust     // Validate cell state
556*c217d954SCole Faust     ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, input_to_cell_weights, (lstm_params.use_layer_norm()) ? nullptr : cell_bias, &cell_state_tmp));
557*c217d954SCole Faust     ARM_COMPUTE_RETURN_ON_ERROR(CLGEMM::validate(output_state_in, &units_out_transposed_info, nullptr, &cell_state_tmp, 1.f, 0.f, GEMMInfo()));
558*c217d954SCole Faust     ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&cell_state_tmp, &cell_state_tmp, &cell_state_tmp, ConvertPolicy::SATURATE));
559*c217d954SCole Faust     if(lstm_params.use_layer_norm())
560*c217d954SCole Faust     {
561*c217d954SCole Faust         ARM_COMPUTE_RETURN_ON_ERROR(CLMeanStdDevNormalizationLayer::validate(&cell_state_tmp));
562*c217d954SCole Faust         ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&cell_state_tmp, lstm_params.cell_layer_norm_weights(), &cell_state_tmp, 1, ConvertPolicy::SATURATE,
563*c217d954SCole Faust                                                                         RoundingPolicy::TO_NEAREST_EVEN));
564*c217d954SCole Faust         ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&cell_state_tmp, cell_bias, &cell_state_tmp, ConvertPolicy::SATURATE));
565*c217d954SCole Faust     }
566*c217d954SCole Faust     ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&cell_state_tmp, nullptr, activation_info));
567*c217d954SCole Faust     ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&cell_state_tmp, &input_gate, &cell_state_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
568*c217d954SCole Faust     ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&cell_state_tmp, &forget_gate, &cell_state_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
569*c217d954SCole Faust     ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&cell_state_tmp, &cell_state_tmp, &cell_state_tmp, ConvertPolicy::SATURATE));
570*c217d954SCole Faust     if(cell_threshold != 0.f)
571*c217d954SCole Faust     {
572*c217d954SCole Faust         ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&cell_state_tmp, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, cell_threshold,
573*c217d954SCole Faust                                                                                                               -cell_threshold)));
574*c217d954SCole Faust     }
575*c217d954SCole Faust 
576*c217d954SCole Faust     std::vector<const ITensorInfo *> in_out_weights;
577*c217d954SCole Faust     in_out_weights.emplace_back(input_to_output_weights);
578*c217d954SCole Faust     in_out_weights.emplace_back(recurrent_to_output_weights);
579*c217d954SCole Faust     TensorShape in_out_weights_concat_shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(in_out_weights, 0);
580*c217d954SCole Faust     TensorInfo  in_out_gate_concat          = TensorInfo(in_out_weights_concat_shape, 1, input->data_type());
581*c217d954SCole Faust     ARM_COMPUTE_RETURN_ON_ERROR(CLConcatenateLayer::validate(in_out_weights, &in_out_gate_concat, Window::DimX));
582*c217d954SCole Faust     // Validate output gate tmp
583*c217d954SCole Faust     ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, input_to_output_weights, (lstm_params.use_layer_norm()) ? nullptr : output_gate_bias, &output_gate_tmp));
584*c217d954SCole Faust 
585*c217d954SCole Faust     if(lstm_params.has_peephole_opt())
586*c217d954SCole Faust     {
587*c217d954SCole Faust         ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&cell_state_tmp, lstm_params.cell_to_output_weights(), &output_gate_tmp, 1, ConvertPolicy::SATURATE,
588*c217d954SCole Faust                                                                         RoundingPolicy::TO_NEAREST_EVEN));
589*c217d954SCole Faust         ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&output_gate_tmp, &output_gate_tmp, &output_gate_tmp, ConvertPolicy::SATURATE));
590*c217d954SCole Faust     }
591*c217d954SCole Faust     if(lstm_params.use_layer_norm())
592*c217d954SCole Faust     {
593*c217d954SCole Faust         ARM_COMPUTE_RETURN_ON_ERROR(CLMeanStdDevNormalizationLayer::validate(&output_gate_tmp));
594*c217d954SCole Faust         ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&output_gate_tmp, lstm_params.output_layer_norm_weights(), &output_gate_tmp, 1, ConvertPolicy::SATURATE,
595*c217d954SCole Faust                                                                         RoundingPolicy::TO_NEAREST_EVEN));
596*c217d954SCole Faust         ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&output_gate_tmp, output_gate_bias, &output_gate_tmp, ConvertPolicy::SATURATE));
597*c217d954SCole Faust     }
598*c217d954SCole Faust     ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&output_gate_tmp, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
599*c217d954SCole Faust 
600*c217d954SCole Faust     // Validate output state
601*c217d954SCole Faust     ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&cell_state_tmp, &cell_state_tmp, activation_info));
602*c217d954SCole Faust     ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplication::validate(&cell_state_tmp, &output_gate_tmp, &output_gate_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN));
603*c217d954SCole Faust     if(lstm_params.has_projection())
604*c217d954SCole Faust     {
605*c217d954SCole Faust         ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(&output_gate_tmp, lstm_params.projection_weights(), lstm_params.projection_bias(), output_state_out));
606*c217d954SCole Faust         if(projection_threshold != 0.f)
607*c217d954SCole Faust         {
608*c217d954SCole Faust             ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(output_state_out, output_state_out,
609*c217d954SCole Faust                                                                     ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -projection_threshold, projection_threshold)));
610*c217d954SCole Faust         }
611*c217d954SCole Faust     }
612*c217d954SCole Faust 
613*c217d954SCole Faust     // Validate copy kernel
614*c217d954SCole Faust     ARM_COMPUTE_RETURN_ON_ERROR(CLCopy::validate(&cell_state_tmp, cell_state_out));
615*c217d954SCole Faust     ARM_COMPUTE_RETURN_ON_ERROR(CLCopy::validate(output_state_out, output));
616*c217d954SCole Faust 
617*c217d954SCole Faust     // Validate scratch concatenation
618*c217d954SCole Faust     std::vector<const ITensorInfo *> inputs_vector_info_raw;
619*c217d954SCole Faust     if(!lstm_params.has_cifg_opt())
620*c217d954SCole Faust     {
621*c217d954SCole Faust         inputs_vector_info_raw.push_back(&input_gate);
622*c217d954SCole Faust     }
623*c217d954SCole Faust     inputs_vector_info_raw.push_back(&cell_state_tmp);
624*c217d954SCole Faust     inputs_vector_info_raw.push_back(&forget_gate);
625*c217d954SCole Faust     inputs_vector_info_raw.push_back(&output_gate_tmp);
626*c217d954SCole Faust 
627*c217d954SCole Faust     ARM_COMPUTE_RETURN_ON_ERROR(CLConcatenateLayer::validate(inputs_vector_info_raw, scratch_buffer, Window::DimX));
628*c217d954SCole Faust     return Status{};
629*c217d954SCole Faust }
630*c217d954SCole Faust 
run()631*c217d954SCole Faust void CLLSTMLayer::run()
632*c217d954SCole Faust {
633*c217d954SCole Faust     prepare();
634*c217d954SCole Faust 
635*c217d954SCole Faust     MemoryGroupResourceScope scope_mg(_memory_group);
636*c217d954SCole Faust 
637*c217d954SCole Faust     _concat_inputs_forget_gate.run();
638*c217d954SCole Faust 
639*c217d954SCole Faust     _fully_connected_forget_gate.run();
640*c217d954SCole Faust 
641*c217d954SCole Faust     if(_run_peephole_opt)
642*c217d954SCole Faust     {
643*c217d954SCole Faust         _pixelwise_mul_forget_gate.run();
644*c217d954SCole Faust         _accum_forget_gate1.run();
645*c217d954SCole Faust     }
646*c217d954SCole Faust     if(_is_layer_norm_lstm)
647*c217d954SCole Faust     {
648*c217d954SCole Faust         _mean_std_norm_forget_gate.run();
649*c217d954SCole Faust         _pixelwise_mul_forget_gate_coeff.run();
650*c217d954SCole Faust         _accum_forget_gate_bias.run();
651*c217d954SCole Faust     }
652*c217d954SCole Faust     _activation_forget_gate.run();
653*c217d954SCole Faust 
654*c217d954SCole Faust     if(_run_cifg_opt)
655*c217d954SCole Faust     {
656*c217d954SCole Faust         _ones_fill.run();
657*c217d954SCole Faust         _subtract_input_gate.run();
658*c217d954SCole Faust     }
659*c217d954SCole Faust     else
660*c217d954SCole Faust     {
661*c217d954SCole Faust         _fully_connected_input_gate.run();
662*c217d954SCole Faust 
663*c217d954SCole Faust         if(_run_peephole_opt)
664*c217d954SCole Faust         {
665*c217d954SCole Faust             _pixelwise_mul_input_gate.run();
666*c217d954SCole Faust             _accum_input_gate1.run();
667*c217d954SCole Faust         }
668*c217d954SCole Faust 
669*c217d954SCole Faust         if(_is_layer_norm_lstm)
670*c217d954SCole Faust         {
671*c217d954SCole Faust             _mean_std_norm_input_gate.run();
672*c217d954SCole Faust             _pixelwise_mul_input_gate_coeff.run();
673*c217d954SCole Faust             _accum_input_gate_bias.run();
674*c217d954SCole Faust         }
675*c217d954SCole Faust         _activation_input_gate.run();
676*c217d954SCole Faust     }
677*c217d954SCole Faust 
678*c217d954SCole Faust     _fully_connected_cell_state.run();
679*c217d954SCole Faust     ITensorPack pack;
680*c217d954SCole Faust     pack.add_tensor(TensorType::ACL_SRC, _recurrent_to_cell_weights);
681*c217d954SCole Faust     pack.add_tensor(TensorType::ACL_DST, &_cell_state_out2);
682*c217d954SCole Faust     CLScheduler::get().enqueue_op(*_transpose_cell_state,
683*c217d954SCole Faust                                   pack,
684*c217d954SCole Faust                                   false);
685*c217d954SCole Faust     _gemm_cell_state1.run();
686*c217d954SCole Faust     _accum_cell_state1.run();
687*c217d954SCole Faust     if(_is_layer_norm_lstm)
688*c217d954SCole Faust     {
689*c217d954SCole Faust         _mean_std_norm_cell_gate.run();
690*c217d954SCole Faust         _pixelwise_mul_cell_gate_coeff.run();
691*c217d954SCole Faust         _accum_cell_gate_bias.run();
692*c217d954SCole Faust     }
693*c217d954SCole Faust     _activation_cell_state.run();
694*c217d954SCole Faust     _pixelwise_mul_cell_state1.run();
695*c217d954SCole Faust     _pixelwise_mul_cell_state2.run();
696*c217d954SCole Faust     _accum_cell_state2.run();
697*c217d954SCole Faust 
698*c217d954SCole Faust     if(_perform_cell_clipping)
699*c217d954SCole Faust     {
700*c217d954SCole Faust         _cell_clip.run();
701*c217d954SCole Faust     }
702*c217d954SCole Faust 
703*c217d954SCole Faust     _fully_connected_output.run();
704*c217d954SCole Faust 
705*c217d954SCole Faust     if(_run_peephole_opt)
706*c217d954SCole Faust     {
707*c217d954SCole Faust         _pixelwise_mul_output_state1.run();
708*c217d954SCole Faust         _accum_output1.run();
709*c217d954SCole Faust     }
710*c217d954SCole Faust     if(_is_layer_norm_lstm)
711*c217d954SCole Faust     {
712*c217d954SCole Faust         _mean_std_norm_output_gate.run();
713*c217d954SCole Faust         _pixelwise_mul_output_gate_coeff.run();
714*c217d954SCole Faust         _accum_output_gate_bias.run();
715*c217d954SCole Faust     }
716*c217d954SCole Faust     _activation_output.run();
717*c217d954SCole Faust 
718*c217d954SCole Faust     _activation_output_state.run();
719*c217d954SCole Faust     _pixelwise_mul_output_state2.run();
720*c217d954SCole Faust 
721*c217d954SCole Faust     if(_has_projection_weights)
722*c217d954SCole Faust     {
723*c217d954SCole Faust         _fully_connected_output_state.run();
724*c217d954SCole Faust         if(_perform_projection_clipping)
725*c217d954SCole Faust         {
726*c217d954SCole Faust             _projection_clip.run();
727*c217d954SCole Faust         }
728*c217d954SCole Faust     }
729*c217d954SCole Faust 
730*c217d954SCole Faust     _copy_cell_state.run();
731*c217d954SCole Faust     _copy_output.run();
732*c217d954SCole Faust 
733*c217d954SCole Faust     _concat_scratch_buffer.run();
734*c217d954SCole Faust }
735*c217d954SCole Faust 
prepare()736*c217d954SCole Faust void CLLSTMLayer::prepare()
737*c217d954SCole Faust {
738*c217d954SCole Faust     if(!_is_prepared)
739*c217d954SCole Faust     {
740*c217d954SCole Faust         _concat_weights_forget_gate.run();
741*c217d954SCole Faust         if(!_run_cifg_opt)
742*c217d954SCole Faust         {
743*c217d954SCole Faust             _concat_weights_input_gate.run();
744*c217d954SCole Faust         }
745*c217d954SCole Faust         _concat_weights_output.run();
746*c217d954SCole Faust         _is_prepared = true;
747*c217d954SCole Faust     }
748*c217d954SCole Faust }
749*c217d954SCole Faust } // namespace arm_compute
750