// xref: /aosp_15_r20/external/armnn/src/armnn/layers/QuantizedLstmLayer.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5 #include "QuantizedLstmLayer.hpp"
6 
7 #include "LayerCloneBase.hpp"
8 
9 #include <armnn/QuantizedLstmParams.hpp>
10 #include <armnn/TypesUtils.hpp>
11 #include <armnn/backends/TensorHandle.hpp>
12 #include <armnn/backends/WorkloadFactory.hpp>
13 
14 namespace armnn
15 {
16 
QuantizedLstmLayer(const char * name)17 QuantizedLstmLayer::QuantizedLstmLayer(const char* name)
18     : Layer(3, 2, LayerType::QuantizedLstm, name)
19 {
20 }
21 
CreateWorkload(const IWorkloadFactory & factory) const22 std::unique_ptr<IWorkload> QuantizedLstmLayer::CreateWorkload(const IWorkloadFactory& factory) const
23 {
24     QuantizedLstmQueueDescriptor descriptor;
25 
26     // QuantizedLstmLayer parameters - there are no optional params
27     descriptor.m_InputToInputWeights  = m_QuantizedLstmParameters.m_InputToInputWeights.get();
28     descriptor.m_InputToForgetWeights = m_QuantizedLstmParameters.m_InputToForgetWeights.get();
29     descriptor.m_InputToCellWeights   = m_QuantizedLstmParameters.m_InputToCellWeights.get();
30     descriptor.m_InputToOutputWeights = m_QuantizedLstmParameters.m_InputToOutputWeights.get();
31 
32     descriptor.m_RecurrentToInputWeights  = m_QuantizedLstmParameters.m_RecurrentToInputWeights.get();
33     descriptor.m_RecurrentToForgetWeights = m_QuantizedLstmParameters.m_RecurrentToForgetWeights.get();
34     descriptor.m_RecurrentToCellWeights   = m_QuantizedLstmParameters.m_RecurrentToCellWeights.get();
35     descriptor.m_RecurrentToOutputWeights = m_QuantizedLstmParameters.m_RecurrentToOutputWeights.get();
36 
37     descriptor.m_InputGateBias  = m_QuantizedLstmParameters.m_InputGateBias.get();
38     descriptor.m_ForgetGateBias = m_QuantizedLstmParameters.m_ForgetGateBias.get();
39     descriptor.m_CellBias       = m_QuantizedLstmParameters.m_CellBias.get();
40     descriptor.m_OutputGateBias = m_QuantizedLstmParameters.m_OutputGateBias.get();
41 
42     SetAdditionalInfo(descriptor);
43 
44     return factory.CreateWorkload(LayerType::QuantizedLstm, descriptor, PrepInfoAndDesc(descriptor));
45 }
46 
Clone(Graph & graph) const47 QuantizedLstmLayer* QuantizedLstmLayer::Clone(Graph& graph) const
48 {
49     auto layer = CloneBase<QuantizedLstmLayer>(graph, GetName());
50 
51     layer->m_QuantizedLstmParameters.m_InputToInputWeights = m_QuantizedLstmParameters.m_InputToInputWeights ?
52             m_QuantizedLstmParameters.m_InputToInputWeights : nullptr;
53     layer->m_QuantizedLstmParameters.m_InputToForgetWeights = m_QuantizedLstmParameters.m_InputToForgetWeights ?
54             m_QuantizedLstmParameters.m_InputToForgetWeights : nullptr;
55     layer->m_QuantizedLstmParameters.m_InputToCellWeights = m_QuantizedLstmParameters.m_InputToCellWeights ?
56             m_QuantizedLstmParameters.m_InputToCellWeights : nullptr;
57     layer->m_QuantizedLstmParameters.m_InputToOutputWeights = m_QuantizedLstmParameters.m_InputToOutputWeights ?
58             m_QuantizedLstmParameters.m_InputToOutputWeights : nullptr;
59 
60     layer->m_QuantizedLstmParameters.m_RecurrentToInputWeights = m_QuantizedLstmParameters.m_RecurrentToInputWeights ?
61             m_QuantizedLstmParameters.m_RecurrentToInputWeights : nullptr;
62     layer->m_QuantizedLstmParameters.m_RecurrentToForgetWeights = m_QuantizedLstmParameters.m_RecurrentToForgetWeights
63             ? m_QuantizedLstmParameters.m_RecurrentToForgetWeights : nullptr;
64     layer->m_QuantizedLstmParameters.m_RecurrentToCellWeights = m_QuantizedLstmParameters.m_RecurrentToCellWeights ?
65             m_QuantizedLstmParameters.m_RecurrentToCellWeights : nullptr;
66     layer->m_QuantizedLstmParameters.m_RecurrentToOutputWeights = m_QuantizedLstmParameters.m_RecurrentToOutputWeights
67             ? m_QuantizedLstmParameters.m_RecurrentToOutputWeights : nullptr;
68 
69     layer->m_QuantizedLstmParameters.m_InputGateBias = m_QuantizedLstmParameters.m_InputGateBias ?
70             m_QuantizedLstmParameters.m_InputGateBias : nullptr;
71     layer->m_QuantizedLstmParameters.m_ForgetGateBias = m_QuantizedLstmParameters.m_ForgetGateBias ?
72             m_QuantizedLstmParameters.m_ForgetGateBias : nullptr;
73     layer->m_QuantizedLstmParameters.m_CellBias = m_QuantizedLstmParameters.m_CellBias ?
74             m_QuantizedLstmParameters.m_CellBias : nullptr;
75     layer->m_QuantizedLstmParameters.m_OutputGateBias = m_QuantizedLstmParameters.m_OutputGateBias ?
76             m_QuantizedLstmParameters.m_OutputGateBias : nullptr;
77 
78     return std::move(layer);
79 }
80 
InferOutputShapes(const std::vector<TensorShape> & inputShapes) const81 std::vector<TensorShape> QuantizedLstmLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
82 {
83     ARMNN_ASSERT(inputShapes.size() == 3);
84 
85     // Get input values for validation
86     unsigned int numBatches = inputShapes[0][0];
87     unsigned int outputSize = inputShapes[1][1];
88 
89     std::vector<TensorShape> outShapes;
90     outShapes.push_back(TensorShape({numBatches, outputSize})); // cellStateOut
91     outShapes.push_back(TensorShape({numBatches, outputSize})); // output
92 
93     return outShapes;
94 }
95 
ValidateTensorShapesFromInputs()96 void QuantizedLstmLayer::ValidateTensorShapesFromInputs()
97 {
98     VerifyLayerConnections(3, CHECK_LOCATION());
99 
100     const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape();
101 
102     VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod);
103 
104     auto inferredShapes = InferOutputShapes(
105     {
106         GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(), // input
107         GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape(), // previousCellStateIn
108         GetInputSlot(2).GetConnection()->GetTensorInfo().GetShape()  // previousOutputIn
109     });
110 
111     ARMNN_ASSERT(inferredShapes.size() == 2);
112 
113     // Check weights and bias for nullptr
114     ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_InputToInputWeights != nullptr,
115                      "QuantizedLstmLayer: m_QuantizedLstmParameters.m_InputToInputWeights should not be null.");
116     ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_InputToForgetWeights != nullptr,
117                      "QuantizedLstmLayer: m_QuantizedLstmParameters.m_InputToForgetWeights should not be null.");
118     ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_InputToCellWeights != nullptr,
119                      "QuantizedLstmLayer: m_QuantizedLstmParameters.m_InputToCellWeights should not be null.");
120     ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_InputToOutputWeights != nullptr,
121                      "QuantizedLstmLayer: m_QuantizedLstmParameters.m_InputToOutputWeights should not be null.");
122 
123     ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_RecurrentToInputWeights != nullptr,
124                      "QuantizedLstmLayer: m_QuantizedLstmParameters.m_RecurrentToInputWeights should not be null.");
125     ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_RecurrentToForgetWeights != nullptr,
126                      "QuantizedLstmLayer: m_QuantizedLstmParameters.m_RecurrentToForgetWeights should not be null.");
127     ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_RecurrentToCellWeights != nullptr,
128                      "QuantizedLstmLayer: m_QuantizedLstmParameters.m_RecurrentToCellWeights should not be null.");
129     ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_RecurrentToOutputWeights != nullptr,
130                      "QuantizedLstmLayer: m_QuantizedLstmParameters.m_RecurrentToOutputWeights should not be null.");
131 
132     ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_InputGateBias != nullptr,
133                      "QuantizedLstmLayer: m_QuantizedLstmParameters.m_InputGateBias should not be null.");
134     ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_ForgetGateBias != nullptr,
135                      "QuantizedLstmLayer: m_QuantizedLstmParameters.m_ForgetGateBias should not be null.");
136     ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_CellBias != nullptr,
137                      "QuantizedLstmLayer: m_QuantizedLstmParameters.m_CellBias should not be null.");
138     ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_OutputGateBias != nullptr,
139                      "QuantizedLstmLayer: m_QuantizedLstmParameters.m_OutputGateBias should not be null.");
140 
141     // Check output TensorShape(s) match inferred shape
142     ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "QuantizedLstmLayer");
143 
144     ValidateAndCopyShape(GetOutputSlot(1).GetTensorInfo().GetShape(),
145                          inferredShapes[1],
146                          m_ShapeInferenceMethod,
147                          "QuantizedLstmLayer",
148                          1);
149 }
150 
GetConstantTensorsByRef() const151 Layer::ImmutableConstantTensors QuantizedLstmLayer::GetConstantTensorsByRef() const
152 {
153     // For API stability DO NOT ALTER order and add new members to the end of vector
154     return
155     {
156         m_QuantizedLstmParameters.m_InputToInputWeights,
157         m_QuantizedLstmParameters.m_InputToForgetWeights,
158         m_QuantizedLstmParameters.m_InputToCellWeights,
159         m_QuantizedLstmParameters.m_InputToOutputWeights,
160 
161         m_QuantizedLstmParameters.m_RecurrentToInputWeights,
162         m_QuantizedLstmParameters.m_RecurrentToForgetWeights,
163         m_QuantizedLstmParameters.m_RecurrentToCellWeights,
164         m_QuantizedLstmParameters.m_RecurrentToOutputWeights,
165 
166         m_QuantizedLstmParameters.m_InputGateBias,
167         m_QuantizedLstmParameters.m_ForgetGateBias,
168         m_QuantizedLstmParameters.m_CellBias,
169         m_QuantizedLstmParameters.m_OutputGateBias
170     };
171 }
172 
ExecuteStrategy(IStrategy & strategy) const173 void QuantizedLstmLayer::ExecuteStrategy(IStrategy& strategy) const
174 {
175     std::vector<ConstTensor> constTensors;
176 
177     ManagedConstTensorHandle managedInputToInputWeights(m_QuantizedLstmParameters.m_InputToInputWeights);
178     ManagedConstTensorHandle managedInputToForgetWeights(m_QuantizedLstmParameters.m_InputToForgetWeights);
179     ManagedConstTensorHandle managedInputToCellWeights(m_QuantizedLstmParameters.m_InputToCellWeights);
180     ManagedConstTensorHandle managedInputToOutputWeights(m_QuantizedLstmParameters.m_InputToOutputWeights);
181 
182     ManagedConstTensorHandle managedRecurrentToInputWeights(m_QuantizedLstmParameters.m_RecurrentToInputWeights);
183     ManagedConstTensorHandle managedRecurrentToForgetWeights(m_QuantizedLstmParameters.m_RecurrentToForgetWeights);
184     ManagedConstTensorHandle managedRecurrentToCellWeights(m_QuantizedLstmParameters.m_RecurrentToCellWeights);
185     ManagedConstTensorHandle managedRecurrentToOutputWeights(m_QuantizedLstmParameters.m_RecurrentToOutputWeights);
186 
187     ManagedConstTensorHandle managedInputGateBias(m_QuantizedLstmParameters.m_InputGateBias);
188     ManagedConstTensorHandle managedForgetGateBias(m_QuantizedLstmParameters.m_ForgetGateBias);
189     ManagedConstTensorHandle managedCellBias(m_QuantizedLstmParameters.m_CellBias);
190     ManagedConstTensorHandle managedOutputGateBias(m_QuantizedLstmParameters.m_OutputGateBias);
191 
192     // InputToX weight tensors
193     if (m_QuantizedLstmParameters.m_InputToInputWeights != nullptr)
194     {
195         constTensors.emplace_back(ConstTensor(managedInputToInputWeights.GetTensorInfo(),
196                                               managedInputToInputWeights.Map()));
197     }
198 
199     if (m_QuantizedLstmParameters.m_InputToForgetWeights != nullptr)
200     {
201         constTensors.emplace_back(ConstTensor(managedInputToForgetWeights.GetTensorInfo(),
202                                               managedInputToForgetWeights.Map()));
203     }
204 
205     if (m_QuantizedLstmParameters.m_InputToCellWeights != nullptr)
206     {
207         constTensors.emplace_back(ConstTensor(managedInputToCellWeights.GetTensorInfo(),
208                                               managedInputToCellWeights.Map()));
209     }
210 
211     if (m_QuantizedLstmParameters.m_InputToOutputWeights != nullptr)
212     {
213         constTensors.emplace_back(ConstTensor(managedInputToOutputWeights.GetTensorInfo(),
214                                               managedInputToOutputWeights.Map()));
215     }
216 
217     // RecurrentToX weight tensors
218     if (m_QuantizedLstmParameters.m_RecurrentToInputWeights != nullptr)
219     {
220         constTensors.emplace_back(ConstTensor(
221                 managedRecurrentToInputWeights.GetTensorInfo(),
222                 managedRecurrentToInputWeights.Map()));
223     }
224 
225     if (m_QuantizedLstmParameters.m_RecurrentToForgetWeights != nullptr)
226     {
227         constTensors.emplace_back(ConstTensor(
228                 managedRecurrentToForgetWeights.GetTensorInfo(),
229                 managedRecurrentToForgetWeights.Map()));
230     }
231 
232     if (m_QuantizedLstmParameters.m_RecurrentToCellWeights != nullptr)
233     {
234         constTensors.emplace_back(ConstTensor(
235                 managedRecurrentToCellWeights.GetTensorInfo(),
236                 managedRecurrentToCellWeights.Map()));
237     }
238 
239     if (m_QuantizedLstmParameters.m_RecurrentToOutputWeights != nullptr)
240     {
241         constTensors.emplace_back(ConstTensor(
242                 managedRecurrentToOutputWeights.GetTensorInfo(),
243                 managedRecurrentToOutputWeights.Map()));
244     }
245 
246     // Bias tensors
247     if (m_QuantizedLstmParameters.m_InputGateBias != nullptr)
248     {
249         constTensors.emplace_back(ConstTensor(managedInputGateBias.GetTensorInfo(),
250                                               managedInputGateBias.Map()));
251     }
252 
253     if (m_QuantizedLstmParameters.m_ForgetGateBias != nullptr)
254     {
255         constTensors.emplace_back(ConstTensor(managedForgetGateBias.GetTensorInfo(),
256                                               managedForgetGateBias.Map()));
257     }
258 
259     if (m_QuantizedLstmParameters.m_CellBias != nullptr)
260     {
261         constTensors.emplace_back(ConstTensor(managedCellBias.GetTensorInfo(),
262                                               managedCellBias.Map()));
263     }
264 
265     if (m_QuantizedLstmParameters.m_OutputGateBias != nullptr)
266     {
267         constTensors.emplace_back(ConstTensor(managedOutputGateBias.GetTensorInfo(),
268                                               managedOutputGateBias.Map()));
269     }
270 
271 
272     strategy.ExecuteStrategy(this, BaseDescriptor(), constTensors, GetName());
273 }
274 
275 } // namespace armnn
276