xref: /aosp_15_r20/external/armnn/src/backends/reference/workloads/RefLstmWorkload.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "RefLstmWorkload.hpp"
7 #include "Activation.hpp"
8 #include "Encoders.hpp"
9 #include "Decoders.hpp"
10 #include "Lstm.hpp"
11 #include "LstmUtils.hpp"
12 #include "RefWorkloadUtils.hpp"
13 
14 namespace armnn
15 {
16 
RefLstmWorkload(const LstmQueueDescriptor & descriptor,const WorkloadInfo & info)17 RefLstmWorkload::RefLstmWorkload(const LstmQueueDescriptor &descriptor, const WorkloadInfo &info)
18     : RefBaseWorkload<LstmQueueDescriptor>(descriptor, info)
19     , m_InputToInputWeightsTensor     (AssignScopedTensorHandle(descriptor.m_InputToInputWeights))
20     , m_InputToForgetWeightsTensor    (AssignScopedTensorHandle(descriptor.m_InputToForgetWeights))
21     , m_InputToCellWeightsTensor      (AssignScopedTensorHandle(descriptor.m_InputToCellWeights))
22     , m_InputToOutputWeightsTensor    (AssignScopedTensorHandle(descriptor.m_InputToOutputWeights))
23     , m_RecurrentToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToInputWeights))
24     , m_RecurrentToForgetWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToForgetWeights))
25     , m_RecurrentToCellWeightsTensor  (AssignScopedTensorHandle(descriptor.m_RecurrentToCellWeights))
26     , m_RecurrentToOutputWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToOutputWeights))
27     , m_CellToInputWeightsTensor      (AssignScopedTensorHandle(descriptor.m_CellToInputWeights))
28     , m_CellToForgetWeightsTensor     (AssignScopedTensorHandle(descriptor.m_CellToForgetWeights))
29     , m_CellToOutputWeightsTensor     (AssignScopedTensorHandle(descriptor.m_CellToOutputWeights))
30     , m_InputGateBiasTensor           (AssignScopedTensorHandle(descriptor.m_InputGateBias))
31     , m_ForgetGateBiasTensor          (AssignScopedTensorHandle(descriptor.m_ForgetGateBias))
32     , m_CellBiasTensor                (AssignScopedTensorHandle(descriptor.m_CellBias))
33     , m_OutputGateBiasTensor          (AssignScopedTensorHandle(descriptor.m_OutputGateBias))
34     , m_ProjectionWeightsTensor       (AssignScopedTensorHandle(descriptor.m_ProjectionWeights))
35     , m_ProjectionBiasTensor          (AssignScopedTensorHandle(descriptor.m_ProjectionBias))
36     , m_InputLayerNormWeights         (AssignScopedTensorHandle(descriptor.m_InputLayerNormWeights))
37     , m_ForgetLayerNormWeights        (AssignScopedTensorHandle(descriptor.m_ForgetLayerNormWeights))
38     , m_CellLayerNormWeights          (AssignScopedTensorHandle(descriptor.m_CellLayerNormWeights))
39     , m_OutputLayerNormWeights        (AssignScopedTensorHandle(descriptor.m_OutputLayerNormWeights))
40 {}
41 
Execute() const42 void RefLstmWorkload::Execute() const
43 {
44     Execute(m_Data.m_Inputs, m_Data.m_Outputs);
45 }
46 
ExecuteAsync(ExecutionData & executionData)47 void RefLstmWorkload::ExecuteAsync(ExecutionData& executionData)
48 {
49     WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data);
50     Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs);
51 }
52 
Execute(std::vector<ITensorHandle * > inputs,std::vector<ITensorHandle * > outputs) const53 void RefLstmWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
54 {
55     // This is a porting of the LSTM::Eval() method in the Android code base
56     // Refer to: android/frameworks/ml/nn/common/operations/LSTM.cpp
57 
58     const TensorInfo& inputInfo = GetTensorInfo(inputs[0]);
59     const TensorInfo& outputInfo = GetTensorInfo(outputs[0]);
60 
61     const TensorShape& inputShape = inputInfo.GetShape();
62 
63     std::unique_ptr<Encoder<float>> outputStateOut = MakeEncoder<float>(outputInfo, outputs[1]->Map());
64     std::unique_ptr<Encoder<float>> cellStateOut   = MakeEncoder<float>(outputInfo, outputs[2]->Map());
65     std::unique_ptr<Encoder<float>> output         = MakeEncoder<float>(outputInfo, outputs[3]->Map());
66 
67     std::unique_ptr<Decoder<float>> cellStateOutDecoder = MakeDecoder<float>(outputInfo, outputs[2]->Map());
68     std::unique_ptr<Decoder<float>> outputDecoder       = MakeDecoder<float>(outputInfo, outputs[3]->Map());
69 
70     std::unique_ptr<Decoder<float>> inputData     = MakeDecoder<float>(inputInfo, inputs[0]->Map());
71     std::unique_ptr<Decoder<float>> outputStateIn = MakeDecoder<float>(inputInfo, inputs[1]->Map());
72     std::unique_ptr<Decoder<float>> cellStateIn   = MakeDecoder<float>(inputInfo, inputs[2]->Map());
73 
74     const uint32_t nBatch = inputShape[0];
75     const uint32_t nCell   = m_InputToOutputWeightsTensor->GetShape()[0];
76 
77     const bool useCifg      = m_Data.m_Parameters.m_CifgEnabled;
78     const bool usePeephole  = m_Data.m_Parameters.m_PeepholeEnabled;
79     const bool useLayerNorm = m_Data.m_Parameters.m_LayerNormEnabled;
80 
81     // Index the scratch buffers pointers to the global scratch buffer.
82     std::unique_ptr<Encoder<float>> inputGateScratch  = MakeEncoder<float>(outputInfo, outputs[0]->Map());
83     std::unique_ptr<Encoder<float>> cellScratch       = MakeEncoder<float>(outputInfo, outputs[0]->Map());
84     std::unique_ptr<Encoder<float>> forgetGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map());
85     std::unique_ptr<Encoder<float>> outputGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map());
86 
87     std::unique_ptr<Decoder<float>> inputGateScratchDecoder =
88         MakeDecoder<float>(outputInfo, outputs[0]->Map());
89     std::unique_ptr<Decoder<float>> cellScratchDecoder =
90         MakeDecoder<float>(outputInfo, outputs[0]->Map());
91     std::unique_ptr<Decoder<float>> forgetGateScratchDecoder =
92         MakeDecoder<float>(outputInfo, outputs[0]->Map());
93     std::unique_ptr<Decoder<float>> outputGateScratchDecoder =
94         MakeDecoder<float>(outputInfo, outputs[0]->Map());
95 
96     if (useCifg)
97     {
98         *cellScratch       += (0 * nCell * nBatch);
99         *forgetGateScratch += (1 * nCell * nBatch);
100         *outputGateScratch += (2 * nCell * nBatch);
101 
102         *cellScratchDecoder       += (0 * nCell * nBatch);
103         *forgetGateScratchDecoder += (1 * nCell * nBatch);
104         *outputGateScratchDecoder += (2 * nCell * nBatch);
105     }
106     else
107     {
108         *inputGateScratch  += (0 * nCell * nBatch);
109         *cellScratch       += (1 * nCell * nBatch);
110         *forgetGateScratch += (2 * nCell * nBatch);
111         *outputGateScratch += (3 * nCell * nBatch);
112 
113         *inputGateScratchDecoder  += (0 * nCell * nBatch);
114         *cellScratchDecoder       += (1 * nCell * nBatch);
115         *forgetGateScratchDecoder += (2 * nCell * nBatch);
116         *outputGateScratchDecoder += (3 * nCell * nBatch);
117     }
118 
119     std::unique_ptr<Decoder<float>> inputToInputWeightsTensor;
120     std::unique_ptr<Decoder<float>> inputToForgetWeightsTensor = MakeDecoder<float>(
121         m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<void>());
122     std::unique_ptr<Decoder<float>> inputToCellWeightsTensor = MakeDecoder<float>(
123         m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<void>());
124     std::unique_ptr<Decoder<float>> inputToOutputWeightsTensor = MakeDecoder<float>(
125         m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<void>());
126 
127     std::unique_ptr<Decoder<float>> recurrentToInputWeightsTensor;
128     std::unique_ptr<Decoder<float>> recurrentToForgetWeightsTensor = MakeDecoder<float>(
129         m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetConstTensor<void>());
130     std::unique_ptr<Decoder<float>> recurrentToCellWeightsTensor = MakeDecoder<float>(
131         m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<void>());
132     std::unique_ptr<Decoder<float>> recurrentToOutputWeightsTensor = MakeDecoder<float>(
133         m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetConstTensor<void>());
134 
135     std::unique_ptr<Decoder<float>> inputGateBiasTensor;
136     std::unique_ptr<Decoder<float>> forgetGateBiasTensor = MakeDecoder<float>(
137         m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetConstTensor<void>());
138     std::unique_ptr<Decoder<float>> cellBiasTensor = MakeDecoder<float>(
139         m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetConstTensor<void>());
140     std::unique_ptr<Decoder<float>> outputGateBiasTensor = MakeDecoder<float>(
141         m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetConstTensor<void>());
142 
143     std::unique_ptr<Decoder<float>> cellToInputWeightsTensor;
144     std::unique_ptr<Decoder<float>> cellToForgetWeightsTensor;
145     std::unique_ptr<Decoder<float>> cellToOutputWeightsTensor;
146 
147     std::unique_ptr<Decoder<float>> projectionWeightsTensor;
148     std::unique_ptr<Decoder<float>> projectionBiasTensor;
149 
150     std::unique_ptr<Decoder<float>> inputLayerNormWeights;
151     std::unique_ptr<Decoder<float>> forgetLayerNormWeights;
152     std::unique_ptr<Decoder<float>> cellLayerNormWeights;
153     std::unique_ptr<Decoder<float>> outputLayerNormWeights;
154 
155     const TensorShape& inputToOutputWeightsShape = m_InputToOutputWeightsTensor->GetShape();
156     const TensorShape& recurrentToOutputWeightsShape = m_RecurrentToOutputWeightsTensor->GetShape();
157 
158     if (useLayerNorm)
159     {
160         if (!useCifg)
161         {
162             inputLayerNormWeights = MakeDecoder<float>(
163                     m_InputLayerNormWeights->GetTensorInfo(), m_InputLayerNormWeights->GetConstTensor<void>());
164         }
165         forgetLayerNormWeights = MakeDecoder<float>(
166                 m_ForgetLayerNormWeights->GetTensorInfo(), m_ForgetLayerNormWeights->GetConstTensor<void>());
167         cellLayerNormWeights = MakeDecoder<float>(
168                 m_CellLayerNormWeights->GetTensorInfo(), m_CellLayerNormWeights->GetConstTensor<void>());
169         outputLayerNormWeights = MakeDecoder<float>(
170                 m_OutputLayerNormWeights->GetTensorInfo(), m_OutputLayerNormWeights->GetConstTensor<void>());
171     }
172 
173     if (!useCifg)
174     {
175         inputToInputWeightsTensor = MakeDecoder<float>(
176             m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<void>());
177         inputGateBiasTensor = MakeDecoder<float>(
178             m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetConstTensor<void>());
179         recurrentToInputWeightsTensor = MakeDecoder<float>(
180             m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetConstTensor<void>());
181     }
182 
183     if (usePeephole)
184     {
185         cellToForgetWeightsTensor = MakeDecoder<float>(
186             m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<void>());
187         cellToOutputWeightsTensor = MakeDecoder<float>(
188             m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<void>());
189     }
190 
191     if (!useCifg && usePeephole)
192     {
193         cellToInputWeightsTensor = MakeDecoder<float>(
194             m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<void>());
195     }
196 
197     if (m_Data.m_Parameters.m_ProjectionEnabled)
198     {
199         projectionWeightsTensor = MakeDecoder<float>(
200             m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<void>());
201         if (m_ProjectionBiasTensor)
202         {
203             projectionBiasTensor = MakeDecoder<float>(
204                 m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<void>());
205         }
206     }
207 
208     LstmImpl(m_Data.m_Parameters,
209                  inputInfo,
210                  outputInfo,
211                  inputToOutputWeightsShape,
212                  recurrentToOutputWeightsShape,
213                  inputData,
214                  outputStateIn,
215                  cellStateIn,
216                  outputStateOut,
217                  cellStateOut,
218                  output,
219                  cellStateOutDecoder,
220                  outputDecoder,
221                  inputToInputWeightsTensor,
222                  inputToForgetWeightsTensor,
223                  inputToCellWeightsTensor,
224                  inputToOutputWeightsTensor,
225                  recurrentToInputWeightsTensor,
226                  recurrentToForgetWeightsTensor,
227                  recurrentToCellWeightsTensor,
228                  recurrentToOutputWeightsTensor,
229                  cellToInputWeightsTensor,
230                  cellToForgetWeightsTensor,
231                  cellToOutputWeightsTensor,
232                  inputGateBiasTensor,
233                  forgetGateBiasTensor,
234                  cellBiasTensor,
235                  outputGateBiasTensor,
236                  projectionWeightsTensor,
237                  projectionBiasTensor,
238                  inputLayerNormWeights,
239                  forgetLayerNormWeights,
240                  cellLayerNormWeights,
241                  outputLayerNormWeights,
242                  inputGateScratch,
243                  cellScratch,
244                  forgetGateScratch,
245                  outputGateScratch,
246                  inputGateScratchDecoder,
247                  cellScratchDecoder,
248                  forgetGateScratchDecoder,
249                  outputGateScratchDecoder,
250                  m_LayerNormEpsilon);
251 }
252 
253 } //namespace armnn
254