1 //
2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "RefLstmWorkload.hpp"
7 #include "Activation.hpp"
8 #include "Encoders.hpp"
9 #include "Decoders.hpp"
10 #include "Lstm.hpp"
11 #include "LstmUtils.hpp"
12 #include "RefWorkloadUtils.hpp"
13
14 namespace armnn
15 {
16
RefLstmWorkload(const LstmQueueDescriptor & descriptor,const WorkloadInfo & info)17 RefLstmWorkload::RefLstmWorkload(const LstmQueueDescriptor &descriptor, const WorkloadInfo &info)
18 : RefBaseWorkload<LstmQueueDescriptor>(descriptor, info)
19 , m_InputToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToInputWeights))
20 , m_InputToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToForgetWeights))
21 , m_InputToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToCellWeights))
22 , m_InputToOutputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToOutputWeights))
23 , m_RecurrentToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToInputWeights))
24 , m_RecurrentToForgetWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToForgetWeights))
25 , m_RecurrentToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToCellWeights))
26 , m_RecurrentToOutputWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToOutputWeights))
27 , m_CellToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToInputWeights))
28 , m_CellToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToForgetWeights))
29 , m_CellToOutputWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToOutputWeights))
30 , m_InputGateBiasTensor (AssignScopedTensorHandle(descriptor.m_InputGateBias))
31 , m_ForgetGateBiasTensor (AssignScopedTensorHandle(descriptor.m_ForgetGateBias))
32 , m_CellBiasTensor (AssignScopedTensorHandle(descriptor.m_CellBias))
33 , m_OutputGateBiasTensor (AssignScopedTensorHandle(descriptor.m_OutputGateBias))
34 , m_ProjectionWeightsTensor (AssignScopedTensorHandle(descriptor.m_ProjectionWeights))
35 , m_ProjectionBiasTensor (AssignScopedTensorHandle(descriptor.m_ProjectionBias))
36 , m_InputLayerNormWeights (AssignScopedTensorHandle(descriptor.m_InputLayerNormWeights))
37 , m_ForgetLayerNormWeights (AssignScopedTensorHandle(descriptor.m_ForgetLayerNormWeights))
38 , m_CellLayerNormWeights (AssignScopedTensorHandle(descriptor.m_CellLayerNormWeights))
39 , m_OutputLayerNormWeights (AssignScopedTensorHandle(descriptor.m_OutputLayerNormWeights))
40 {}
41
Execute() const42 void RefLstmWorkload::Execute() const
43 {
44 Execute(m_Data.m_Inputs, m_Data.m_Outputs);
45 }
46
ExecuteAsync(ExecutionData & executionData)47 void RefLstmWorkload::ExecuteAsync(ExecutionData& executionData)
48 {
49 WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data);
50 Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs);
51 }
52
Execute(std::vector<ITensorHandle * > inputs,std::vector<ITensorHandle * > outputs) const53 void RefLstmWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
54 {
55 // This is a porting of the LSTM::Eval() method in the Android code base
56 // Refer to: android/frameworks/ml/nn/common/operations/LSTM.cpp
57
58 const TensorInfo& inputInfo = GetTensorInfo(inputs[0]);
59 const TensorInfo& outputInfo = GetTensorInfo(outputs[0]);
60
61 const TensorShape& inputShape = inputInfo.GetShape();
62
63 std::unique_ptr<Encoder<float>> outputStateOut = MakeEncoder<float>(outputInfo, outputs[1]->Map());
64 std::unique_ptr<Encoder<float>> cellStateOut = MakeEncoder<float>(outputInfo, outputs[2]->Map());
65 std::unique_ptr<Encoder<float>> output = MakeEncoder<float>(outputInfo, outputs[3]->Map());
66
67 std::unique_ptr<Decoder<float>> cellStateOutDecoder = MakeDecoder<float>(outputInfo, outputs[2]->Map());
68 std::unique_ptr<Decoder<float>> outputDecoder = MakeDecoder<float>(outputInfo, outputs[3]->Map());
69
70 std::unique_ptr<Decoder<float>> inputData = MakeDecoder<float>(inputInfo, inputs[0]->Map());
71 std::unique_ptr<Decoder<float>> outputStateIn = MakeDecoder<float>(inputInfo, inputs[1]->Map());
72 std::unique_ptr<Decoder<float>> cellStateIn = MakeDecoder<float>(inputInfo, inputs[2]->Map());
73
74 const uint32_t nBatch = inputShape[0];
75 const uint32_t nCell = m_InputToOutputWeightsTensor->GetShape()[0];
76
77 const bool useCifg = m_Data.m_Parameters.m_CifgEnabled;
78 const bool usePeephole = m_Data.m_Parameters.m_PeepholeEnabled;
79 const bool useLayerNorm = m_Data.m_Parameters.m_LayerNormEnabled;
80
81 // Index the scratch buffers pointers to the global scratch buffer.
82 std::unique_ptr<Encoder<float>> inputGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map());
83 std::unique_ptr<Encoder<float>> cellScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map());
84 std::unique_ptr<Encoder<float>> forgetGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map());
85 std::unique_ptr<Encoder<float>> outputGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map());
86
87 std::unique_ptr<Decoder<float>> inputGateScratchDecoder =
88 MakeDecoder<float>(outputInfo, outputs[0]->Map());
89 std::unique_ptr<Decoder<float>> cellScratchDecoder =
90 MakeDecoder<float>(outputInfo, outputs[0]->Map());
91 std::unique_ptr<Decoder<float>> forgetGateScratchDecoder =
92 MakeDecoder<float>(outputInfo, outputs[0]->Map());
93 std::unique_ptr<Decoder<float>> outputGateScratchDecoder =
94 MakeDecoder<float>(outputInfo, outputs[0]->Map());
95
96 if (useCifg)
97 {
98 *cellScratch += (0 * nCell * nBatch);
99 *forgetGateScratch += (1 * nCell * nBatch);
100 *outputGateScratch += (2 * nCell * nBatch);
101
102 *cellScratchDecoder += (0 * nCell * nBatch);
103 *forgetGateScratchDecoder += (1 * nCell * nBatch);
104 *outputGateScratchDecoder += (2 * nCell * nBatch);
105 }
106 else
107 {
108 *inputGateScratch += (0 * nCell * nBatch);
109 *cellScratch += (1 * nCell * nBatch);
110 *forgetGateScratch += (2 * nCell * nBatch);
111 *outputGateScratch += (3 * nCell * nBatch);
112
113 *inputGateScratchDecoder += (0 * nCell * nBatch);
114 *cellScratchDecoder += (1 * nCell * nBatch);
115 *forgetGateScratchDecoder += (2 * nCell * nBatch);
116 *outputGateScratchDecoder += (3 * nCell * nBatch);
117 }
118
119 std::unique_ptr<Decoder<float>> inputToInputWeightsTensor;
120 std::unique_ptr<Decoder<float>> inputToForgetWeightsTensor = MakeDecoder<float>(
121 m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<void>());
122 std::unique_ptr<Decoder<float>> inputToCellWeightsTensor = MakeDecoder<float>(
123 m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<void>());
124 std::unique_ptr<Decoder<float>> inputToOutputWeightsTensor = MakeDecoder<float>(
125 m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<void>());
126
127 std::unique_ptr<Decoder<float>> recurrentToInputWeightsTensor;
128 std::unique_ptr<Decoder<float>> recurrentToForgetWeightsTensor = MakeDecoder<float>(
129 m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetConstTensor<void>());
130 std::unique_ptr<Decoder<float>> recurrentToCellWeightsTensor = MakeDecoder<float>(
131 m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<void>());
132 std::unique_ptr<Decoder<float>> recurrentToOutputWeightsTensor = MakeDecoder<float>(
133 m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetConstTensor<void>());
134
135 std::unique_ptr<Decoder<float>> inputGateBiasTensor;
136 std::unique_ptr<Decoder<float>> forgetGateBiasTensor = MakeDecoder<float>(
137 m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetConstTensor<void>());
138 std::unique_ptr<Decoder<float>> cellBiasTensor = MakeDecoder<float>(
139 m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetConstTensor<void>());
140 std::unique_ptr<Decoder<float>> outputGateBiasTensor = MakeDecoder<float>(
141 m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetConstTensor<void>());
142
143 std::unique_ptr<Decoder<float>> cellToInputWeightsTensor;
144 std::unique_ptr<Decoder<float>> cellToForgetWeightsTensor;
145 std::unique_ptr<Decoder<float>> cellToOutputWeightsTensor;
146
147 std::unique_ptr<Decoder<float>> projectionWeightsTensor;
148 std::unique_ptr<Decoder<float>> projectionBiasTensor;
149
150 std::unique_ptr<Decoder<float>> inputLayerNormWeights;
151 std::unique_ptr<Decoder<float>> forgetLayerNormWeights;
152 std::unique_ptr<Decoder<float>> cellLayerNormWeights;
153 std::unique_ptr<Decoder<float>> outputLayerNormWeights;
154
155 const TensorShape& inputToOutputWeightsShape = m_InputToOutputWeightsTensor->GetShape();
156 const TensorShape& recurrentToOutputWeightsShape = m_RecurrentToOutputWeightsTensor->GetShape();
157
158 if (useLayerNorm)
159 {
160 if (!useCifg)
161 {
162 inputLayerNormWeights = MakeDecoder<float>(
163 m_InputLayerNormWeights->GetTensorInfo(), m_InputLayerNormWeights->GetConstTensor<void>());
164 }
165 forgetLayerNormWeights = MakeDecoder<float>(
166 m_ForgetLayerNormWeights->GetTensorInfo(), m_ForgetLayerNormWeights->GetConstTensor<void>());
167 cellLayerNormWeights = MakeDecoder<float>(
168 m_CellLayerNormWeights->GetTensorInfo(), m_CellLayerNormWeights->GetConstTensor<void>());
169 outputLayerNormWeights = MakeDecoder<float>(
170 m_OutputLayerNormWeights->GetTensorInfo(), m_OutputLayerNormWeights->GetConstTensor<void>());
171 }
172
173 if (!useCifg)
174 {
175 inputToInputWeightsTensor = MakeDecoder<float>(
176 m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<void>());
177 inputGateBiasTensor = MakeDecoder<float>(
178 m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetConstTensor<void>());
179 recurrentToInputWeightsTensor = MakeDecoder<float>(
180 m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetConstTensor<void>());
181 }
182
183 if (usePeephole)
184 {
185 cellToForgetWeightsTensor = MakeDecoder<float>(
186 m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<void>());
187 cellToOutputWeightsTensor = MakeDecoder<float>(
188 m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<void>());
189 }
190
191 if (!useCifg && usePeephole)
192 {
193 cellToInputWeightsTensor = MakeDecoder<float>(
194 m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<void>());
195 }
196
197 if (m_Data.m_Parameters.m_ProjectionEnabled)
198 {
199 projectionWeightsTensor = MakeDecoder<float>(
200 m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<void>());
201 if (m_ProjectionBiasTensor)
202 {
203 projectionBiasTensor = MakeDecoder<float>(
204 m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<void>());
205 }
206 }
207
208 LstmImpl(m_Data.m_Parameters,
209 inputInfo,
210 outputInfo,
211 inputToOutputWeightsShape,
212 recurrentToOutputWeightsShape,
213 inputData,
214 outputStateIn,
215 cellStateIn,
216 outputStateOut,
217 cellStateOut,
218 output,
219 cellStateOutDecoder,
220 outputDecoder,
221 inputToInputWeightsTensor,
222 inputToForgetWeightsTensor,
223 inputToCellWeightsTensor,
224 inputToOutputWeightsTensor,
225 recurrentToInputWeightsTensor,
226 recurrentToForgetWeightsTensor,
227 recurrentToCellWeightsTensor,
228 recurrentToOutputWeightsTensor,
229 cellToInputWeightsTensor,
230 cellToForgetWeightsTensor,
231 cellToOutputWeightsTensor,
232 inputGateBiasTensor,
233 forgetGateBiasTensor,
234 cellBiasTensor,
235 outputGateBiasTensor,
236 projectionWeightsTensor,
237 projectionBiasTensor,
238 inputLayerNormWeights,
239 forgetLayerNormWeights,
240 cellLayerNormWeights,
241 outputLayerNormWeights,
242 inputGateScratch,
243 cellScratch,
244 forgetGateScratch,
245 outputGateScratch,
246 inputGateScratchDecoder,
247 cellScratchDecoder,
248 forgetGateScratchDecoder,
249 outputGateScratchDecoder,
250 m_LayerNormEpsilon);
251 }
252
253 } //namespace armnn
254