1 //
2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
#include <armnn/Descriptors.hpp>
#include <armnn/LstmParams.hpp>
#include <armnn/backends/Workload.hpp>
#include <armnn/backends/WorkloadData.hpp>

#include "arm_compute/runtime/IFunction.h"
#include "arm_compute/runtime/Tensor.h"
#include "arm_compute/runtime/NEON/functions/NELSTMLayer.h"
#include "arm_compute/runtime/NEON/functions/NEPermute.h"
#include "arm_compute/runtime/NEON/functions/NESplit.h"
#include "arm_compute/runtime/NEON/functions/NEConcatenateLayer.h"

#include <memory>
#include <vector>
17 
18 namespace armnn
19 {
20 
21 class NeonUnidirectionalSequenceLstmFloatWorkload : public FloatWorkload<UnidirectionalSequenceLstmQueueDescriptor>
22 {
23 public:
24     NeonUnidirectionalSequenceLstmFloatWorkload(const UnidirectionalSequenceLstmQueueDescriptor& descriptor,
25                                                 const WorkloadInfo& info);
26     virtual void Execute() const override;
27 
28 private:
29 
30     //
31     // ACL layers required to fully form a Unidirectional Sequence LSTM layer.
32     //
33 
34     // permutation for input (only used when input is batch major)
35     mutable std::unique_ptr<arm_compute::NEPermute> m_Permute1;
36     mutable std::unique_ptr<arm_compute::IFunction> m_Splitter;
37     mutable std::vector<std::unique_ptr<arm_compute::NELSTMLayer>> m_Layers;
38     mutable std::unique_ptr<arm_compute::NEConcatenateLayer> m_Concat;
39     // permutation for output (only used when input is batch major)
40     mutable std::unique_ptr<arm_compute::NEPermute> m_Permute2;
41 
42     //
43     // ACL LSTM arm_compute::Tensors.
44     //
45     std::unique_ptr<arm_compute::Tensor> m_InputToInputWeightsTensor;
46     std::unique_ptr<arm_compute::Tensor> m_InputToForgetWeightsTensor;
47     std::unique_ptr<arm_compute::Tensor> m_InputToCellWeightsTensor;
48     std::unique_ptr<arm_compute::Tensor> m_InputToOutputWeightsTensor;
49     std::unique_ptr<arm_compute::Tensor> m_RecurrentToInputWeightsTensor;
50     std::unique_ptr<arm_compute::Tensor> m_RecurrentToForgetWeightsTensor;
51     std::unique_ptr<arm_compute::Tensor> m_RecurrentToCellWeightsTensor;
52     std::unique_ptr<arm_compute::Tensor> m_RecurrentToOutputWeightsTensor;
53     std::unique_ptr<arm_compute::Tensor> m_CellToInputWeightsTensor;
54     std::unique_ptr<arm_compute::Tensor> m_CellToForgetWeightsTensor;
55     std::unique_ptr<arm_compute::Tensor> m_CellToOutputWeightsTensor;
56     std::unique_ptr<arm_compute::Tensor> m_InputGateBiasTensor;
57     std::unique_ptr<arm_compute::Tensor> m_ForgetGateBiasTensor;
58     std::unique_ptr<arm_compute::Tensor> m_CellBiasTensor;
59     std::unique_ptr<arm_compute::Tensor> m_OutputGateBiasTensor;
60     std::unique_ptr<arm_compute::Tensor> m_ProjectionWeightsTensor;
61     std::unique_ptr<arm_compute::Tensor> m_ProjectionBiasTensor;
62 
63     std::unique_ptr<arm_compute::Tensor> m_ScratchBuffer;
64 
65     std::unique_ptr<arm_compute::Tensor> m_InputLayerNormWeightsTensor;
66     std::unique_ptr<arm_compute::Tensor> m_ForgetLayerNormWeightsTensor;
67     std::unique_ptr<arm_compute::Tensor> m_CellLayerNormWeightsTensor;
68     std::unique_ptr<arm_compute::Tensor> m_OutputLayerNormWeightsTensor;
69 
70     //
71     // Additional ACL arm_compute::Tensors and std::vector<arm_compute::Tensor>.
72     // Required to perform splitting, concatenation and permutations.
73     //
74     arm_compute::Tensor m_PermuteFirstOut;
75     std::vector<arm_compute::Tensor> m_SplitterOutputsTensors;
76     std::vector<arm_compute::Tensor> m_ConcatInputsTensors;
77     std::vector<arm_compute::ITensor*> m_SplitterOutputs;
78     std::vector<const arm_compute::ITensor*> m_ConcatInputs;
79     arm_compute::Tensor concat_out;
80 
81     void FreeUnusedTensors();
82 };
83 
84 arm_compute::Status
85 NeonUnidirectionalSequenceLstmFloatWorkloadValidate(const TensorInfo& input,
86                                                     const TensorInfo& outputStateIn,
87                                                     const TensorInfo& cellStateIn,
88                                                     const TensorInfo& outputStateOut,
89                                                     const TensorInfo& cellStateOut,
90                                                     const TensorInfo& output,
91                                                     const UnidirectionalSequenceLstmDescriptor& descriptor,
92                                                     const LstmInputParamsInfo& paramsInfo);
93 
94 } //namespace armnn
95