/*
 * Copyright (c) 2020-2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifndef ARM_COMPUTE_NEQLSTMLAYER_H
#define ARM_COMPUTE_NEQLSTMLAYER_H

#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/NEON/functions/NEActivationLayer.h"
#include "arm_compute/runtime/NEON/functions/NEArithmeticAddition.h"
#include "arm_compute/runtime/NEON/functions/NEArithmeticSubtraction.h"
#include "arm_compute/runtime/NEON/functions/NECopy.h"
#include "arm_compute/runtime/NEON/functions/NEDequantizationLayer.h"
#include "arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h"
#include "arm_compute/runtime/NEON/functions/NEGEMMLowpOutputStage.h"
#include "arm_compute/runtime/NEON/functions/NEPixelWiseMultiplication.h"
#include "arm_compute/runtime/NEON/functions/NEQuantizationLayer.h"
#include "arm_compute/runtime/NEON/functions/NETranspose.h"
#include "arm_compute/runtime/common/LSTMParams.h"

#include <memory>

namespace arm_compute
{
// Forward declarations
class ITensor;
class ITensorInfo;
class NEQLSTMLayerNormalizationKernel;
namespace cpu
{
namespace kernels
{
class CpuGemmLowpMatrixAReductionKernel;
} // namespace kernels
} // namespace cpu
/** Basic function to run @ref NEQLSTMLayer
 *
 * This function calls the following kernels:
 *
 * -# @ref NEActivationLayer                                Activation functions (tanh and logistic)
 * -# @ref NEArithmeticAddition                             Elementwise addition
 * -# @ref NEArithmeticSubtraction                          Elementwise subtraction
 * -# @ref NECopy                                           Copy kernel for copying output_state_out to output
 * -# @ref NEGEMMLowpMatrixMultiplyCore                     Quantized matrix multiplication core. Accumulators are 32-bit integers
 * -# @ref NEGEMMLowpOutputStage                            Convert 32-bit integers into QSYMM16
 * -# @ref cpu::kernels::CpuGemmLowpMatrixAReductionKernel  For precomputing effective biases to use
 * -# @ref NEPixelWiseMultiplication                        Elementwise multiplication
 * -# @ref NETranspose                                      Transpose function for reshaping the weights
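 *
 * Example: configuring and running the function. This is a minimal sketch, not
 * taken verbatim from the library documentation; it assumes all tensors have
 * already been initialised and allocated with suitable shapes and quantization
 * info, and that the placeholder names (input_scale, hidden_zero_point, the
 * weight/bias tensors, ...) are provided by the caller:
 *
 * @code
 * LSTMParams<ITensor> lstm_params;
 * // Mandatory quantization metadata for the gate intermediates and the hidden state.
 * lstm_params.set_matmul_scale_params(input_scale, forget_scale, cell_scale, output_scale);
 * lstm_params.set_hidden_state_params(hidden_zero_point, hidden_scale);
 *
 * NEQLSTMLayer qlstm;
 * qlstm.configure(&input,
 *                 &input_to_forget_w, &input_to_cell_w, &input_to_output_w,
 *                 &recurrent_to_forget_w, &recurrent_to_cell_w, &recurrent_to_output_w,
 *                 &forget_gate_bias, &cell_bias, &output_gate_bias,
 *                 &cell_state_in, &output_state_in,
 *                 &cell_state_out, &output_state_out, &output,
 *                 lstm_params);
 * qlstm.run(); // executes one QLSTM step
 * @endcode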
 */
class NEQLSTMLayer : public IFunction
{
public:
    /** Default constructor */
    NEQLSTMLayer(std::shared_ptr<IMemoryManager> memory_manager = nullptr);
    /** Prevent instances of this class from being copied (As this class contains pointers) */
    NEQLSTMLayer(const NEQLSTMLayer &) = delete;
    /** Prevent instances of this class from being moved (As this class contains pointers) */
    NEQLSTMLayer(NEQLSTMLayer &&) = delete;
    /** Prevent instances of this class from being copied (As this class contains pointers) */
    NEQLSTMLayer &operator=(const NEQLSTMLayer &) = delete;
    /** Prevent instances of this class from being moved (As this class contains pointers) */
    NEQLSTMLayer &operator=(NEQLSTMLayer &&) = delete;
    /** Default destructor */
    ~NEQLSTMLayer();
    /** Initialize function's tensors.
     *
     * Valid data layouts:
     * - All
     *
     * Valid data type configurations:
     * |src0          |src1 - src6 |src7 - src9 |src10  |src11         |dst0   |dst1 - dst2   |
     * |:-------------|:-----------|:-----------|:------|:-------------|:------|:-------------|
     * |QASYMM8_SIGNED|QSYMM8      |S32         |QSYMM16|QASYMM8_SIGNED|QSYMM16|QASYMM8_SIGNED|
     *
     * @param[in]  input                       Source tensor. Input is a 2D tensor with dimensions [input_size, batch_size]. Data types supported: QASYMM8_SIGNED.
     * @param[in]  input_to_forget_weights     2D weights tensor with dimensions [input_size, num_units]. Data type supported: QSYMM8.
     * @param[in]  input_to_cell_weights       2D weights tensor with dimensions [input_size, num_units]. Data type supported: QSYMM8.
     * @param[in]  input_to_output_weights     2D weights tensor with dimensions [input_size, num_units]. Data type supported: QSYMM8.
     * @param[in]  recurrent_to_forget_weights 2D weights tensor with dimensions [output_size, num_units]. Data type supported: QSYMM8.
     * @param[in]  recurrent_to_cell_weights   2D weights tensor with dimensions [output_size, num_units]. Data type supported: QSYMM8.
     * @param[in]  recurrent_to_output_weights 2D weights tensor with dimensions [output_size, num_units]. Data type supported: QSYMM8.
     * @param[in]  forget_gate_bias            1D weights tensor with dimensions [num_units]. Data type supported: S32.
     * @param[in]  cell_bias                   1D weights tensor with dimensions [num_units]. Data type supported: S32.
     * @param[in]  output_gate_bias            1D weights tensor with dimensions [num_units]. Data type supported: S32.
     * @param[in]  cell_state_in               2D tensor with dimensions [num_units, batch_size]. Data type supported: QSYMM16.
     * @param[in]  output_state_in             2D tensor with dimensions [output_size, batch_size]. Data type supported: Same as @p input.
     * @param[out] cell_state_out              Destination tensor. Output is a 2D tensor with dimensions [num_units, batch_size]. Data type supported: QSYMM16.
     * @param[out] output_state_out            Destination tensor. Output is a 2D tensor with dimensions [output_size, batch_size]. Data types supported: Same as @p input.
     * @param[out] output                      Destination tensor. Output is a 2D tensor with dimensions [output_size, batch_size]. Data types supported: Same as @p input.
     * @param[in]  lstm_params                 Weights tensors used in peephole, CIFG and layer normalization optimizations:
     *                                         input_intermediate_scale   Scale of the intermediate result of matmul, i.e. input to layer normalization, at input gate.
     *                                         forget_intermediate_scale  Scale of the intermediate result of matmul, i.e. input to layer normalization, at forget gate.
     *                                         cell_intermediate_scale    Scale of the intermediate result of matmul, i.e. input to layer normalization, at cell gate.
     *                                         output_intermediate_scale  Scale of the intermediate result of matmul, i.e. input to layer normalization, at output gate.
     *                                         hidden_state_zero          The zero point of the hidden state.
     *                                         hidden_state_scale         The scale of the hidden state.
     *                                         input_to_input_weights     (Optional) 2D weights tensor with dimensions [input_size, num_units]. Data type supported: QSYMM8.
     *                                         recurrent_to_input_weights (Optional) 2D weights tensor with dimensions [output_size, num_units]. Data type supported: QSYMM8.
     *                                         cell_to_input_weights      (Optional) 1D weights tensor with dimensions [num_units]. Can be nullptr. Data type supported: QSYMM16.
     *                                         cell_to_forget_weights     (Optional) 1D weights tensor with dimensions [num_units]. Data type supported: QSYMM16.
     *                                         cell_to_output_weights     (Optional) 1D weights tensor with dimensions [num_units]. Data type supported: QSYMM16.
     *                                         input_gate_bias            (Optional) 1D weights tensor with dimensions [num_units]. Data type supported: S32.
     *                                         projection_weights         (Optional) 2D weights tensor with dimensions [output_size, num_units]. Data type supported: QSYMM8.
     *                                         projection_bias            (Optional) 1D weights tensor with dimensions [output_size]. Data type supported: S32.
     *                                         input_layer_norm_weights   (Optional) 1D weights tensor with dimensions [num_units]. Data type supported: QSYMM16.
     *                                         forget_layer_norm_weights  (Optional) 1D weights tensor with dimensions [num_units]. Data type supported: QSYMM16.
     *                                         cell_layer_norm_weights    (Optional) 1D weights tensor with dimensions [num_units]. Data type supported: QSYMM16.
     *                                         output_layer_norm_weights  (Optional) 1D weights tensor with dimensions [num_units]. Data type supported: QSYMM16.
     *                                         cell_threshold             (Optional) The clipping threshold for the cell state, such that values are bound within [-cell_clip, cell_clip].
     *                                                                    If set to 0.0 then clipping is disabled.
     *                                         projection_threshold       (Optional) The clipping threshold for the output from the projection layer, such that values are bound within
     *                                                                    [-proj_clip, proj_clip]. If set to 0.0 then clipping is disabled.
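     *
     * The optional features are enabled through the corresponding @ref LSTMParams
     * setters before calling this method. A sketch; the weight tensor names below
     * are placeholders of this example, not part of the API:
     *
     * @code
     * lstm_params.set_peephole_params(&cell_to_forget_w, &cell_to_output_w);  // peephole connections
     * lstm_params.set_projection_params(&projection_w, &projection_bias);    // projection layer
     * lstm_params.set_projection_clip_params(0.5f);                          // clip projection output
     * lstm_params.set_layer_normalization_params(&input_ln_w, &forget_ln_w,
     *                                            &cell_ln_w, &output_ln_w);  // layer normalization
     * @endcode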
     */
    void configure(const ITensor *input,
                   const ITensor *input_to_forget_weights, const ITensor *input_to_cell_weights, const ITensor *input_to_output_weights,
                   const ITensor *recurrent_to_forget_weights, const ITensor *recurrent_to_cell_weights, const ITensor *recurrent_to_output_weights,
                   const ITensor *forget_gate_bias, const ITensor *cell_bias, const ITensor *output_gate_bias,
                   const ITensor *cell_state_in, ITensor *output_state_in,
                   ITensor *cell_state_out, ITensor *output_state_out, ITensor *output,
                   const LSTMParams<ITensor> &lstm_params);

    /** Static function to check if given info will lead to a valid configuration of @ref NEQLSTMLayer
     *
     * @param[in] input                       Source tensor info. Input is a 2D tensor info with dimensions [input_size, batch_size]. Data types supported: QASYMM8_SIGNED.
     * @param[in] input_to_forget_weights     2D weights tensor info with dimensions [input_size, num_units]. Data type supported: QSYMM8.
     * @param[in] input_to_cell_weights       2D weights tensor info with dimensions [input_size, num_units]. Data type supported: QSYMM8.
     * @param[in] input_to_output_weights     2D weights tensor info with dimensions [input_size, num_units]. Data type supported: QSYMM8.
     * @param[in] recurrent_to_forget_weights 2D weights tensor info with dimensions [output_size, num_units]. Data type supported: QSYMM8.
     * @param[in] recurrent_to_cell_weights   2D weights tensor info with dimensions [output_size, num_units]. Data type supported: QSYMM8.
     * @param[in] recurrent_to_output_weights 2D weights tensor info with dimensions [output_size, num_units]. Data type supported: QSYMM8.
     * @param[in] forget_gate_bias            1D weights tensor info with dimensions [num_units]. Data type supported: S32.
     * @param[in] cell_bias                   1D weights tensor info with dimensions [num_units]. Data type supported: S32.
     * @param[in] output_gate_bias            1D weights tensor info with dimensions [num_units]. Data type supported: S32.
     * @param[in] cell_state_in               2D tensor info with dimensions [num_units, batch_size]. Data type supported: QSYMM16.
     * @param[in] output_state_in             2D tensor info with dimensions [output_size, batch_size]. Data type supported: Same as @p input.
     * @param[in] cell_state_out              Destination tensor info. Output is a 2D tensor info with dimensions [num_units, batch_size]. Data type supported: QSYMM16.
     * @param[in] output_state_out            Destination tensor info. Output is a 2D tensor info with dimensions [output_size, batch_size]. Data types supported: Same as @p input.
     * @param[in] output                      Destination tensor info. Output is a 2D tensor info with dimensions [output_size, batch_size]. Data types supported: Same as @p input.
     * @param[in] lstm_params                 Weights tensors info used in peephole, CIFG and layer normalization optimizations:
     *                                        input_intermediate_scale   Scale of the intermediate result of matmul, i.e. input to layer normalization, at input gate.
     *                                        forget_intermediate_scale  Scale of the intermediate result of matmul, i.e. input to layer normalization, at forget gate.
     *                                        cell_intermediate_scale    Scale of the intermediate result of matmul, i.e. input to layer normalization, at cell gate.
     *                                        output_intermediate_scale  Scale of the intermediate result of matmul, i.e. input to layer normalization, at output gate.
     *                                        hidden_state_zero          The zero point of the hidden state.
     *                                        hidden_state_scale         The scale of the hidden state.
     *                                        input_to_input_weights     (Optional) 2D weights tensor with dimensions [input_size, num_units]. Data type supported: QSYMM8.
     *                                        recurrent_to_input_weights (Optional) 2D weights tensor with dimensions [output_size, num_units]. Data type supported: QSYMM8.
     *                                        cell_to_input_weights      (Optional) 1D weights tensor with dimensions [num_units]. Can be nullptr. Data type supported: QSYMM16.
     *                                        cell_to_forget_weights     (Optional) 1D weights tensor with dimensions [num_units]. Data type supported: QSYMM16.
     *                                        cell_to_output_weights     (Optional) 1D weights tensor with dimensions [num_units]. Data type supported: QSYMM16.
     *                                        input_gate_bias            (Optional) 1D weights tensor with dimensions [num_units]. Data type supported: S32.
     *                                        projection_weights         (Optional) 2D weights tensor with dimensions [output_size, num_units]. Data type supported: QSYMM8.
     *                                        projection_bias            (Optional) 1D weights tensor with dimensions [output_size]. Data type supported: S32.
     *                                        input_layer_norm_weights   (Optional) 1D weights tensor with dimensions [num_units]. Data type supported: QSYMM16.
     *                                        forget_layer_norm_weights  (Optional) 1D weights tensor with dimensions [num_units]. Data type supported: QSYMM16.
     *                                        cell_layer_norm_weights    (Optional) 1D weights tensor with dimensions [num_units]. Data type supported: QSYMM16.
     *                                        output_layer_norm_weights  (Optional) 1D weights tensor with dimensions [num_units]. Data type supported: QSYMM16.
     *                                        cell_threshold             (Optional) The clipping threshold for the cell state, such that values are bound within [-cell_clip, cell_clip].
     *                                                                   If set to 0.0 then clipping is disabled.
     *                                        projection_threshold       (Optional) The clipping threshold for the output from the projection layer, such that values are bound within
     *                                                                   [-proj_clip, proj_clip]. If set to 0.0 then clipping is disabled.
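     *
     * A typical pre-flight check before @ref configure. This is a sketch: the
     * *_w / *_bias objects are assumed to be the caller's runtime tensors, whose
     * @ref ITensorInfo is obtained via ITensor::info(), and lstm_params_info is
     * the matching LSTMParams<ITensorInfo>:
     *
     * @code
     * const Status status = NEQLSTMLayer::validate(input.info(),
     *                                              input_to_forget_w.info(), input_to_cell_w.info(), input_to_output_w.info(),
     *                                              recurrent_to_forget_w.info(), recurrent_to_cell_w.info(), recurrent_to_output_w.info(),
     *                                              forget_gate_bias.info(), cell_bias.info(), output_gate_bias.info(),
     *                                              cell_state_in.info(), output_state_in.info(),
     *                                              cell_state_out.info(), output_state_out.info(), output.info(),
     *                                              lstm_params_info);
     * ARM_COMPUTE_ERROR_THROW_ON(status);
     * @endcode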
     * @return a status
     */
    static Status validate(const ITensorInfo *input,
                           const ITensorInfo *input_to_forget_weights, const ITensorInfo *input_to_cell_weights, const ITensorInfo *input_to_output_weights,
                           const ITensorInfo *recurrent_to_forget_weights, const ITensorInfo *recurrent_to_cell_weights, const ITensorInfo *recurrent_to_output_weights,
                           const ITensorInfo *forget_gate_bias, const ITensorInfo *cell_bias, const ITensorInfo *output_gate_bias,
                           const ITensorInfo *cell_state_in, const ITensorInfo *output_state_in,
                           const ITensorInfo *cell_state_out, const ITensorInfo *output_state_out, const ITensorInfo *output,
                           const LSTMParams<ITensorInfo> &lstm_params);

    // Inherited methods overridden:
    void run() override;
    void prepare() override;

private:
    enum class LayerNormGate : uint8_t
    {
        Forget,
        Cell,
        Input,
        Output,
        Count
    };
    static constexpr uint8_t  _layer_norm_count                    = static_cast<uint8_t>(LayerNormGate::Count);
    static constexpr uint32_t _out_state_output_size_dimension_idx = 0;

    /** Internal method to configure matrix multiplication plus output stage of each gate.
     *
     * @param[in] mm                    Matrix multiplication function to use.
     * @param[in] outstage              Output stage function to use.
     * @param[in] gemmlowp_info         GEMMLowp metadata to be used by the output stage.
     * @param[in] mm_input              Input tensor to matrix multiplication function.
     * @param[in] mm_weights            Weights tensor to matrix multiplication function.
     * @param[in] bias                  Bias tensor to matrix multiplication function.
     * @param[in] mm_res                Tensor to be used for storing the result of the matrix multiplication.
     * @param[in] outstage_res          Tensor to be used for storing the result of the output stage.
     * @param[in] gemmlowp_scale        Real multiplier to be used computing multiplier and shift for requantization.
     * @param[in] mm_res_info           Tensor info to be used to initialize the matrix multiplication result tensor.
     * @param[in] outstage_tensor_info  Tensor info to be used to initialize the output stage result tensor.
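     *
     * In essence, the per-gate pattern wired by this helper is (a conceptual
     * sketch using the public functions, not the verbatim implementation;
     * @p gemmlowp_scale is assumed to be folded into the multiplier/shift of
     * @p gemmlowp_info beforehand):
     *
     * @code
     * mm.configure(mm_input, mm_weights, nullptr, mm_res);           // quantized GEMM, S32 accumulators
     * outstage.configure(mm_res, bias, outstage_res, gemmlowp_info); // add bias, requantize accumulators
     * @endcode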
     */
    void configure_mm(NEGEMMLowpMatrixMultiplyCore &mm, NEGEMMLowpOutputStage &outstage, GEMMLowpOutputStageInfo &gemmlowp_info,
                      const ITensor *mm_input, const ITensor *mm_weights, const ITensor *bias, Tensor *mm_res,
                      Tensor *outstage_res, float gemmlowp_scale,
                      const TensorInfo &mm_res_info, const TensorInfo &outstage_tensor_info);

    MemoryGroup _memory_group;

    /** A small internal kernel that performs the copy between two tensors */
    class TensorCopyKernel
    {
        static constexpr uint32_t max_dimension_supported = 2;

        ITensor *_src{ nullptr };
        ITensor *_dst{ nullptr };
        size_t   _row_size{};
        Window   _window{};

    public:
        /** Destructor */
        ~TensorCopyKernel();
        /** Static function to check if given info will lead to a valid configuration of @ref NEQLSTMLayer::TensorCopyKernel
         *
         * @param[in] src Source tensor info.
         * @param[in] dst Destination tensor info.
         *
         * @return a status
         */
        static Status validate(const ITensorInfo &src, const ITensorInfo &dst);
        /** Set the input and output tensors.
         *
         * @param[in]  src Source tensor
         * @param[out] dst Destination tensor
         */
        void configure(ITensor &src, ITensor &dst);
        /** Run the kernel */
        void run();
    };

    // Functions used

    NEDequantizationLayer _dequantize_input_to_forget_weights;
    NEQuantizationLayer   _quantize_input_to_forget_weights;
    NETranspose           _transpose_input_to_forget_weights;
    NETranspose           _transpose_input_to_cell_weights;
    NETranspose           _transpose_input_to_output_weights;
    NETranspose           _transpose_input_to_input_weights;
    NETranspose           _transpose_recurrent_to_forget_weights;
    NETranspose           _transpose_recurrent_to_cell_weights;
    NETranspose           _transpose_recurrent_to_output_weights;
    NETranspose           _transpose_recurrent_to_input_weights;
    NETranspose           _transpose_projection_weights;
    std::unique_ptr<cpu::kernels::CpuGemmLowpMatrixAReductionKernel> _input_to_input_reduction;
    std::unique_ptr<cpu::kernels::CpuGemmLowpMatrixAReductionKernel> _recurrent_to_input_reduction;
    std::unique_ptr<cpu::kernels::CpuGemmLowpMatrixAReductionKernel> _input_to_forget_reduction;
    std::unique_ptr<cpu::kernels::CpuGemmLowpMatrixAReductionKernel> _recurrent_to_forget_reduction;
    std::unique_ptr<cpu::kernels::CpuGemmLowpMatrixAReductionKernel> _input_to_cell_reduction;
    std::unique_ptr<cpu::kernels::CpuGemmLowpMatrixAReductionKernel> _recurrent_to_cell_reduction;
    std::unique_ptr<cpu::kernels::CpuGemmLowpMatrixAReductionKernel> _input_to_output_reduction;
    std::unique_ptr<cpu::kernels::CpuGemmLowpMatrixAReductionKernel> _recurrent_to_output_reduction;
    std::unique_ptr<cpu::kernels::CpuGemmLowpMatrixAReductionKernel> _projection_reduction;
    NEArithmeticAddition         _projection_bias_add;
    NEGEMMLowpMatrixMultiplyCore _mm_input_to_forget;
    NEGEMMLowpMatrixMultiplyCore _mm_recurrent_to_forget;
    NEPixelWiseMultiplication    _pixelwise_mul_cell_to_forget;
    NEGEMMLowpOutputStage        _input_to_forget_outstage;
    NEGEMMLowpOutputStage        _recurrent_to_forget_outstage;
    NEGEMMLowpOutputStage        _cell_to_forget_outstage;
    NEArithmeticAddition         _accumulate_input_recurrent_forget;
    NEArithmeticAddition         _accumulate_cell_forget;
    NEActivationLayer            _forget_gate_sigmoid;
    NEGEMMLowpMatrixMultiplyCore _mm_input_to_cell;
    NEGEMMLowpOutputStage        _input_to_cell_outstage;
    NEGEMMLowpMatrixMultiplyCore _mm_recurrent_to_cell;
    NEGEMMLowpOutputStage        _recurrent_to_cell_outstage;
    NEArithmeticAddition         _accumulate_input_recurrent_modulation;
    NEActivationLayer            _cell_gate_tanh;
    NEArithmeticSubtraction      _input_gate_sub;
    NEGEMMLowpMatrixMultiplyCore _mm_input_to_input;
    NEGEMMLowpOutputStage        _input_to_input_outstage;
    NEGEMMLowpMatrixMultiplyCore _mm_recurrent_to_input;
    NEGEMMLowpOutputStage        _recurrent_to_input_outstage;
    NEArithmeticAddition         _accumulate_input_recurrent_input;
    NEPixelWiseMultiplication    _pixelwise_mul_cell_to_input;
    NEGEMMLowpOutputStage        _cell_to_input_outstage;
    NEArithmeticAddition         _accumulate_cell_input;
    NEActivationLayer            _input_gate_sigmoid;
    NEPixelWiseMultiplication    _pixelwise_mul_forget_cell;
    NEPixelWiseMultiplication    _pixelwise_mul_input_cell;
    NEArithmeticAddition         _add_forget_cell;
    NEActivationLayer            _cell_clip;
    NEGEMMLowpMatrixMultiplyCore _mm_input_to_output;
    NEGEMMLowpOutputStage        _input_to_output_outstage;
    NEGEMMLowpMatrixMultiplyCore _mm_recurrent_to_output;
    NEGEMMLowpOutputStage        _recurrent_to_output_outstage;
    NEArithmeticAddition         _accumulate_input_recurrent_output;
    NEPixelWiseMultiplication    _pixelwise_mul_cell_to_output;
    NEGEMMLowpOutputStage        _cell_to_output_outstage;
    NEArithmeticAddition         _accumulate_cell_to_output;
    NEActivationLayer            _output_gate_sigmoid;
    NEActivationLayer            _hidden_tanh;
    NEPixelWiseMultiplication    _pixelwise_mul_hidden;
    NEGEMMLowpOutputStage        _hidden_outstage;
    NEGEMMLowpMatrixMultiplyCore _mm_projection;
    NEGEMMLowpOutputStage        _projection_outstage;
    NEArithmeticAddition         _accumulate_projection;
    NEActivationLayer            _projection_clip;

    TensorCopyKernel _projection_bias_copy;
    TensorCopyKernel _projection_output_to_accumulate_copy;
    TensorCopyKernel _projection_accumulate_to_output_copy;
    TensorCopyKernel _hidden_to_output_copy;

    std::array<std::unique_ptr<NEQLSTMLayerNormalizationKernel>, _layer_norm_count> _layer_norms;

    NECopy _copy_output;

    // Tensor pointers
    const ITensor *_input_to_input_weights{ nullptr };
    const ITensor *_recurrent_to_input_weights{ nullptr };
    const ITensor *_projection_bias{ nullptr };
    const ITensor *_input_to_forget_weights{ nullptr };
    const ITensor *_input_to_cell_weights{ nullptr };
    const ITensor *_input_to_output_weights{ nullptr };
    const ITensor *_recurrent_to_forget_weights{ nullptr };
    const ITensor *_recurrent_to_cell_weights{ nullptr };
    const ITensor *_recurrent_to_output_weights{ nullptr };
    const ITensor *_projection_weights{ nullptr };
    std::array<const ITensor *, _layer_norm_count> _layer_norm_weights{};
    std::array<const ITensor *, _layer_norm_count> _layer_norm_bias{};

    using LayerNormIndexType = typename std::underlying_type<LayerNormGate>::type;

    inline LayerNormIndexType getGateIndex(LayerNormGate g)
    {
        return static_cast<LayerNormIndexType>(g);
    }

    inline void set_layer_norm_weight(const ITensor *t, LayerNormGate g)
    {
        _layer_norm_weights[getGateIndex(g)] = t;
    }

    inline void set_layer_norm_bias(const ITensor *t, LayerNormGate g)
    {
        _layer_norm_bias[getGateIndex(g)] = t;
    }

    inline const ITensor *get_layer_norm_weight(LayerNormGate g)
    {
        return _layer_norm_weights[getGateIndex(g)];
    }

    inline const ITensor *get_layer_norm_bias(LayerNormGate g)
    {
        return _layer_norm_bias[getGateIndex(g)];
    }

    inline std::unique_ptr<NEQLSTMLayerNormalizationKernel> &get_layer_norm(LayerNormGate g)
    {
        return _layer_norms[getGateIndex(g)];
    }

    void configure_layer_norm(LayerNormGate g, const ITensor *in);
    static Status validate_layer_norm(const ITensorInfo &in, const ITensorInfo &weight, const ITensorInfo &bias);

    // Temporary tensors
    Tensor _input_to_forget_weights_f32{ nullptr };
    Tensor _input_to_forget_weights_symm8{ nullptr };

    Tensor _input_to_forget_weights_transposed{ nullptr };
    Tensor _input_to_cell_weights_transposed{ nullptr };
    Tensor _input_to_output_weights_transposed{ nullptr };
    Tensor _input_to_input_weights_transposed{ nullptr };
    Tensor _recurrent_to_forget_weights_transposed{ nullptr };
    Tensor _recurrent_to_cell_weights_transposed{ nullptr };
    Tensor _recurrent_to_output_weights_transposed{ nullptr };
    Tensor _recurrent_to_input_weights_transposed{ nullptr };
    Tensor _projection_weights_transposed{ nullptr };
    Tensor _input_to_input_eff_bias{ nullptr };
    Tensor _recurrent_to_input_eff_bias{ nullptr };
    Tensor _input_to_forget_eff_bias{ nullptr };
    Tensor _recurrent_to_forget_eff_bias{ nullptr };
    Tensor _input_to_cell_eff_bias{ nullptr };
    Tensor _recurrent_to_cell_eff_bias{ nullptr };
    Tensor _input_to_output_eff_bias{ nullptr };
    Tensor _recurrent_to_output_eff_bias{ nullptr };
    Tensor _projection_reduction_res{ nullptr };
    Tensor _projection_eff_bias{ nullptr };
    Tensor _mm_input_to_forget_res{ nullptr };
    Tensor _mm_recurrent_to_forget_res{ nullptr };
    Tensor _mul_cell_to_forget_res{ nullptr };
    Tensor _input_to_forget_outstage_res{ nullptr };
    Tensor _cell_to_forget_outstage_res{ nullptr };
    Tensor _recurrent_to_forget_outstage_res{ nullptr };
    Tensor _forget_gate{ nullptr };
    Tensor _mm_input_to_cell_res{ nullptr };
    Tensor _input_to_cell_outstage_res{ nullptr };
    Tensor _mm_recurrent_to_cell_res{ nullptr };
    Tensor _recurrent_to_cell_outstage_res{ nullptr };
    Tensor _cell_gate{ nullptr };
    Tensor _mul_input_cell_res{ nullptr };
    Tensor _mm_input_to_input_res{ nullptr };
    Tensor _input_to_input_outstage_res{ nullptr };
    Tensor _mm_recurrent_to_input_res{ nullptr };
    Tensor _mul_cell_to_input_res{ nullptr };
    Tensor _cell_to_input_outstage_res{ nullptr };
    Tensor _recurrent_to_input_outstage_res{ nullptr };
    Tensor _input_gate{ nullptr };
    Tensor _mm_input_to_output_res{ nullptr };
    Tensor _input_to_output_outstage_res{ nullptr };
    Tensor _mm_recurrent_to_output_res{ nullptr };
    Tensor _mul_cell_to_output_res{ nullptr };
    Tensor _cell_to_output_outstage_res{ nullptr };
    Tensor _recurrent_to_output_outstage_res{ nullptr };
    Tensor _output_gate{ nullptr };
    Tensor _hidden_mul_res{ nullptr };
    Tensor _hidden_gate{ nullptr };
    Tensor _mm_projection_res{ nullptr };
    Tensor _projection_outstage_res{ nullptr };
    Tensor _projection_out_res{ nullptr };
    Tensor _projection_accumulate_res{ nullptr };
    Tensor _ones{ nullptr };
    std::array<Tensor, _layer_norm_count> _layer_norm_output{};

    inline Tensor &get_layer_norm_output(LayerNormGate g)
    {
        return _layer_norm_output[getGateIndex(g)];
    }

    bool _is_prepared{ false };
    bool _has_cifg{ false };
    bool _has_cell_clipping{ false };
    bool _has_projection{ false };
    bool _has_projection_clipping{ false };
    bool _has_peephole{ false };
    bool _has_layer_norm{ false };
    bool _projection_tensor_copy_required{ false };
    bool _convert_input_to_forget_weights_to_qsymm8{ false };
};
} // namespace arm_compute
#endif /* ARM_COMPUTE_NEQLSTMLAYER_H */