1 // 2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include <armnn/Descriptors.hpp> 9 #include <armnn/LstmParams.hpp> 10 #include <armnn/backends/Workload.hpp> 11 #include <armnn/backends/WorkloadData.hpp> 12 13 #include "arm_compute/runtime/NEON/functions/NELSTMLayer.h" 14 #include "arm_compute/runtime/NEON/functions/NEPermute.h" 15 #include "arm_compute/runtime/NEON/functions/NESplit.h" 16 #include "arm_compute/runtime/NEON/functions/NEConcatenateLayer.h" 17 18 namespace armnn 19 { 20 21 class NeonUnidirectionalSequenceLstmFloatWorkload : public FloatWorkload<UnidirectionalSequenceLstmQueueDescriptor> 22 { 23 public: 24 NeonUnidirectionalSequenceLstmFloatWorkload(const UnidirectionalSequenceLstmQueueDescriptor& descriptor, 25 const WorkloadInfo& info); 26 virtual void Execute() const override; 27 28 private: 29 30 // 31 // ACL layers required to fully form a Unidirectional Sequence LSTM layer. 32 // 33 34 // permutation for input (only used when input is batch major) 35 mutable std::unique_ptr<arm_compute::NEPermute> m_Permute1; 36 mutable std::unique_ptr<arm_compute::IFunction> m_Splitter; 37 mutable std::vector<std::unique_ptr<arm_compute::NELSTMLayer>> m_Layers; 38 mutable std::unique_ptr<arm_compute::NEConcatenateLayer> m_Concat; 39 // permutation for output (only used when input is batch major) 40 mutable std::unique_ptr<arm_compute::NEPermute> m_Permute2; 41 42 // 43 // ACL LSTM arm_compute::Tensors. 44 // 45 std::unique_ptr<arm_compute::Tensor> m_InputToInputWeightsTensor; 46 std::unique_ptr<arm_compute::Tensor> m_InputToForgetWeightsTensor; 47 std::unique_ptr<arm_compute::Tensor> m_InputToCellWeightsTensor; 48 std::unique_ptr<arm_compute::Tensor> m_InputToOutputWeightsTensor; 49 std::unique_ptr<arm_compute::Tensor> m_RecurrentToInputWeightsTensor; 50 std::unique_ptr<arm_compute::Tensor> m_RecurrentToForgetWeightsTensor; 51 std::unique_ptr<arm_compute::Tensor> m_RecurrentToCellWeightsTensor; 52 std::unique_ptr<arm_compute::Tensor> m_RecurrentToOutputWeightsTensor; 53 std::unique_ptr<arm_compute::Tensor> m_CellToInputWeightsTensor; 54 std::unique_ptr<arm_compute::Tensor> m_CellToForgetWeightsTensor; 55 std::unique_ptr<arm_compute::Tensor> m_CellToOutputWeightsTensor; 56 std::unique_ptr<arm_compute::Tensor> m_InputGateBiasTensor; 57 std::unique_ptr<arm_compute::Tensor> m_ForgetGateBiasTensor; 58 std::unique_ptr<arm_compute::Tensor> m_CellBiasTensor; 59 std::unique_ptr<arm_compute::Tensor> m_OutputGateBiasTensor; 60 std::unique_ptr<arm_compute::Tensor> m_ProjectionWeightsTensor; 61 std::unique_ptr<arm_compute::Tensor> m_ProjectionBiasTensor; 62 63 std::unique_ptr<arm_compute::Tensor> m_ScratchBuffer; 64 65 std::unique_ptr<arm_compute::Tensor> m_InputLayerNormWeightsTensor; 66 std::unique_ptr<arm_compute::Tensor> m_ForgetLayerNormWeightsTensor; 67 std::unique_ptr<arm_compute::Tensor> m_CellLayerNormWeightsTensor; 68 std::unique_ptr<arm_compute::Tensor> m_OutputLayerNormWeightsTensor; 69 70 // 71 // Additional ACL arm_compute::Tensors and std::vector<arm_compute::Tensor>. 72 // Required to perform splitting, concatenation and permutations. 73 // 74 arm_compute::Tensor m_PermuteFirstOut; 75 std::vector<arm_compute::Tensor> m_SplitterOutputsTensors; 76 std::vector<arm_compute::Tensor> m_ConcatInputsTensors; 77 std::vector<arm_compute::ITensor*> m_SplitterOutputs; 78 std::vector<const arm_compute::ITensor*> m_ConcatInputs; 79 arm_compute::Tensor concat_out; 80 81 void FreeUnusedTensors(); 82 }; 83 84 arm_compute::Status 85 NeonUnidirectionalSequenceLstmFloatWorkloadValidate(const TensorInfo& input, 86 const TensorInfo& outputStateIn, 87 const TensorInfo& cellStateIn, 88 const TensorInfo& outputStateOut, 89 const TensorInfo& cellStateOut, 90 const TensorInfo& output, 91 const UnidirectionalSequenceLstmDescriptor& descriptor, 92 const LstmInputParamsInfo& paramsInfo); 93 94 } //namespace armnn 95