//
// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "QuantizedLstmLayer.hpp"

#include "LayerCloneBase.hpp"

#include <armnn/QuantizedLstmParams.hpp>
#include <armnn/TypesUtils.hpp>
#include <armnn/backends/TensorHandle.hpp>
#include <armnn/backends/WorkloadFactory.hpp>

namespace armnn
{

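// QuantizedLstmLayer: a quantized LSTM cell with three inputs (input,
// previousCellStateIn, previousOutputIn) and two outputs (cellStateOut, output).
// All twelve weight and bias tensors are mandatory constant parameters.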
QuantizedLstmLayer::QuantizedLstmLayer(const char* name)
    : Layer(3, 2, LayerType::QuantizedLstm, name)
{
}

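// Packs the layer's constant tensors into a QuantizedLstmQueueDescriptor and asks
// the backend factory to create the matching workload.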
std::unique_ptr<IWorkload> QuantizedLstmLayer::CreateWorkload(const IWorkloadFactory& factory) const
{
    QuantizedLstmQueueDescriptor descriptor;

    // QuantizedLstmLayer parameters - there are no optional params
    descriptor.m_InputToInputWeights  = m_QuantizedLstmParameters.m_InputToInputWeights.get();
    descriptor.m_InputToForgetWeights = m_QuantizedLstmParameters.m_InputToForgetWeights.get();
    descriptor.m_InputToCellWeights   = m_QuantizedLstmParameters.m_InputToCellWeights.get();
    descriptor.m_InputToOutputWeights = m_QuantizedLstmParameters.m_InputToOutputWeights.get();

    descriptor.m_RecurrentToInputWeights  = m_QuantizedLstmParameters.m_RecurrentToInputWeights.get();
    descriptor.m_RecurrentToForgetWeights = m_QuantizedLstmParameters.m_RecurrentToForgetWeights.get();
    descriptor.m_RecurrentToCellWeights   = m_QuantizedLstmParameters.m_RecurrentToCellWeights.get();
    descriptor.m_RecurrentToOutputWeights = m_QuantizedLstmParameters.m_RecurrentToOutputWeights.get();

    descriptor.m_InputGateBias  = m_QuantizedLstmParameters.m_InputGateBias.get();
    descriptor.m_ForgetGateBias = m_QuantizedLstmParameters.m_ForgetGateBias.get();
    descriptor.m_CellBias       = m_QuantizedLstmParameters.m_CellBias.get();
    descriptor.m_OutputGateBias = m_QuantizedLstmParameters.m_OutputGateBias.get();

    SetAdditionalInfo(descriptor);

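    // PrepInfoAndDesc gathers the connected input/output TensorInfos into the
    // WorkloadInfo that accompanies the descriptor.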
    return factory.CreateWorkload(LayerType::QuantizedLstm, descriptor, PrepInfoAndDesc(descriptor));
}

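// Clones the layer into the given graph. The parameter members are
// std::shared_ptr<ConstTensorHandle>, so copying them shares ownership of the
// underlying constant tensors rather than duplicating the data.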
QuantizedLstmLayer* QuantizedLstmLayer::Clone(Graph& graph) const
{
    auto layer = CloneBase<QuantizedLstmLayer>(graph, GetName());

    layer->m_QuantizedLstmParameters.m_InputToInputWeights  = m_QuantizedLstmParameters.m_InputToInputWeights;
    layer->m_QuantizedLstmParameters.m_InputToForgetWeights = m_QuantizedLstmParameters.m_InputToForgetWeights;
    layer->m_QuantizedLstmParameters.m_InputToCellWeights   = m_QuantizedLstmParameters.m_InputToCellWeights;
    layer->m_QuantizedLstmParameters.m_InputToOutputWeights = m_QuantizedLstmParameters.m_InputToOutputWeights;

    layer->m_QuantizedLstmParameters.m_RecurrentToInputWeights  = m_QuantizedLstmParameters.m_RecurrentToInputWeights;
    layer->m_QuantizedLstmParameters.m_RecurrentToForgetWeights = m_QuantizedLstmParameters.m_RecurrentToForgetWeights;
    layer->m_QuantizedLstmParameters.m_RecurrentToCellWeights   = m_QuantizedLstmParameters.m_RecurrentToCellWeights;
    layer->m_QuantizedLstmParameters.m_RecurrentToOutputWeights = m_QuantizedLstmParameters.m_RecurrentToOutputWeights;

    layer->m_QuantizedLstmParameters.m_InputGateBias  = m_QuantizedLstmParameters.m_InputGateBias;
    layer->m_QuantizedLstmParameters.m_ForgetGateBias = m_QuantizedLstmParameters.m_ForgetGateBias;
    layer->m_QuantizedLstmParameters.m_CellBias       = m_QuantizedLstmParameters.m_CellBias;
    layer->m_QuantizedLstmParameters.m_OutputGateBias = m_QuantizedLstmParameters.m_OutputGateBias;

    return layer;
}

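// Infers the shapes of the two outputs (cellStateOut, output) from the three
// input shapes. input is [numBatches, inputSize]; previousCellStateIn and
// previousOutputIn are [numBatches, outputSize].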
std::vector<TensorShape> QuantizedLstmLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
{
    ARMNN_ASSERT(inputShapes.size() == 3);

    // numBatches comes from the input shape, outputSize from previousCellStateIn
    unsigned int numBatches = inputShapes[0][0];
    unsigned int outputSize = inputShapes[1][1];

    std::vector<TensorShape> outShapes;
    outShapes.push_back(TensorShape({numBatches, outputSize})); // cellStateOut
    outShapes.push_back(TensorShape({numBatches, outputSize})); // output

    return outShapes;
}

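// Validates that the shapes of the connected inputs, the constant tensors and the
// two outputs are consistent, copying inferred shapes onto the outputs where the
// shape inference method allows it.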
void QuantizedLstmLayer::ValidateTensorShapesFromInputs()
{
    VerifyLayerConnections(3, CHECK_LOCATION());

    const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape();

    VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod);

    auto inferredShapes = InferOutputShapes(
    {
        GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(), // input
        GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape(), // previousCellStateIn
        GetInputSlot(2).GetConnection()->GetTensorInfo().GetShape()  // previousOutputIn
    });

    ARMNN_ASSERT(inferredShapes.size() == 2);

    // Check weights and bias for nullptr
    ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_InputToInputWeights != nullptr,
                     "QuantizedLstmLayer: m_QuantizedLstmParameters.m_InputToInputWeights should not be null.");
    ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_InputToForgetWeights != nullptr,
                     "QuantizedLstmLayer: m_QuantizedLstmParameters.m_InputToForgetWeights should not be null.");
    ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_InputToCellWeights != nullptr,
                     "QuantizedLstmLayer: m_QuantizedLstmParameters.m_InputToCellWeights should not be null.");
    ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_InputToOutputWeights != nullptr,
                     "QuantizedLstmLayer: m_QuantizedLstmParameters.m_InputToOutputWeights should not be null.");

    ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_RecurrentToInputWeights != nullptr,
                     "QuantizedLstmLayer: m_QuantizedLstmParameters.m_RecurrentToInputWeights should not be null.");
    ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_RecurrentToForgetWeights != nullptr,
                     "QuantizedLstmLayer: m_QuantizedLstmParameters.m_RecurrentToForgetWeights should not be null.");
    ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_RecurrentToCellWeights != nullptr,
                     "QuantizedLstmLayer: m_QuantizedLstmParameters.m_RecurrentToCellWeights should not be null.");
    ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_RecurrentToOutputWeights != nullptr,
                     "QuantizedLstmLayer: m_QuantizedLstmParameters.m_RecurrentToOutputWeights should not be null.");

    ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_InputGateBias != nullptr,
                     "QuantizedLstmLayer: m_QuantizedLstmParameters.m_InputGateBias should not be null.");
    ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_ForgetGateBias != nullptr,
                     "QuantizedLstmLayer: m_QuantizedLstmParameters.m_ForgetGateBias should not be null.");
    ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_CellBias != nullptr,
                     "QuantizedLstmLayer: m_QuantizedLstmParameters.m_CellBias should not be null.");
    ARMNN_ASSERT_MSG(m_QuantizedLstmParameters.m_OutputGateBias != nullptr,
                     "QuantizedLstmLayer: m_QuantizedLstmParameters.m_OutputGateBias should not be null.");

    // Check output TensorShape(s) match inferred shape
    ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "QuantizedLstmLayer");

    ValidateAndCopyShape(GetOutputSlot(1).GetTensorInfo().GetShape(),
                         inferredShapes[1],
                         m_ShapeInferenceMethod,
                         "QuantizedLstmLayer",
                         1);
}

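// Exposes the layer's constant tensors by reference, e.g. so they can be
// inspected or serialised without copying.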
Layer::ImmutableConstantTensors QuantizedLstmLayer::GetConstantTensorsByRef() const
{
    // For API stability DO NOT ALTER order and add new members to the end of vector
    return
    {
        m_QuantizedLstmParameters.m_InputToInputWeights,
        m_QuantizedLstmParameters.m_InputToForgetWeights,
        m_QuantizedLstmParameters.m_InputToCellWeights,
        m_QuantizedLstmParameters.m_InputToOutputWeights,

        m_QuantizedLstmParameters.m_RecurrentToInputWeights,
        m_QuantizedLstmParameters.m_RecurrentToForgetWeights,
        m_QuantizedLstmParameters.m_RecurrentToCellWeights,
        m_QuantizedLstmParameters.m_RecurrentToOutputWeights,

        m_QuantizedLstmParameters.m_InputGateBias,
        m_QuantizedLstmParameters.m_ForgetGateBias,
        m_QuantizedLstmParameters.m_CellBias,
        m_QuantizedLstmParameters.m_OutputGateBias
    };
}

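// Maps every non-null constant tensor into a ConstTensor and hands the collection,
// together with an empty descriptor, to the visiting strategy.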
void QuantizedLstmLayer::ExecuteStrategy(IStrategy& strategy) const
{
    std::vector<ConstTensor> constTensors;

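    // ManagedConstTensorHandle keeps each constant tensor mapped while in scope
    // and unmaps it again on destruction.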
    ManagedConstTensorHandle managedInputToInputWeights(m_QuantizedLstmParameters.m_InputToInputWeights);
    ManagedConstTensorHandle managedInputToForgetWeights(m_QuantizedLstmParameters.m_InputToForgetWeights);
    ManagedConstTensorHandle managedInputToCellWeights(m_QuantizedLstmParameters.m_InputToCellWeights);
    ManagedConstTensorHandle managedInputToOutputWeights(m_QuantizedLstmParameters.m_InputToOutputWeights);

    ManagedConstTensorHandle managedRecurrentToInputWeights(m_QuantizedLstmParameters.m_RecurrentToInputWeights);
    ManagedConstTensorHandle managedRecurrentToForgetWeights(m_QuantizedLstmParameters.m_RecurrentToForgetWeights);
    ManagedConstTensorHandle managedRecurrentToCellWeights(m_QuantizedLstmParameters.m_RecurrentToCellWeights);
    ManagedConstTensorHandle managedRecurrentToOutputWeights(m_QuantizedLstmParameters.m_RecurrentToOutputWeights);

    ManagedConstTensorHandle managedInputGateBias(m_QuantizedLstmParameters.m_InputGateBias);
    ManagedConstTensorHandle managedForgetGateBias(m_QuantizedLstmParameters.m_ForgetGateBias);
    ManagedConstTensorHandle managedCellBias(m_QuantizedLstmParameters.m_CellBias);
    ManagedConstTensorHandle managedOutputGateBias(m_QuantizedLstmParameters.m_OutputGateBias);

    // InputToX weight tensors
    if (m_QuantizedLstmParameters.m_InputToInputWeights != nullptr)
    {
        constTensors.emplace_back(ConstTensor(managedInputToInputWeights.GetTensorInfo(),
                                              managedInputToInputWeights.Map()));
    }

    if (m_QuantizedLstmParameters.m_InputToForgetWeights != nullptr)
    {
        constTensors.emplace_back(ConstTensor(managedInputToForgetWeights.GetTensorInfo(),
                                              managedInputToForgetWeights.Map()));
    }

    if (m_QuantizedLstmParameters.m_InputToCellWeights != nullptr)
    {
        constTensors.emplace_back(ConstTensor(managedInputToCellWeights.GetTensorInfo(),
                                              managedInputToCellWeights.Map()));
    }

    if (m_QuantizedLstmParameters.m_InputToOutputWeights != nullptr)
    {
        constTensors.emplace_back(ConstTensor(managedInputToOutputWeights.GetTensorInfo(),
                                              managedInputToOutputWeights.Map()));
    }

    // RecurrentToX weight tensors
    if (m_QuantizedLstmParameters.m_RecurrentToInputWeights != nullptr)
    {
        constTensors.emplace_back(ConstTensor(managedRecurrentToInputWeights.GetTensorInfo(),
                                              managedRecurrentToInputWeights.Map()));
    }

    if (m_QuantizedLstmParameters.m_RecurrentToForgetWeights != nullptr)
    {
        constTensors.emplace_back(ConstTensor(managedRecurrentToForgetWeights.GetTensorInfo(),
                                              managedRecurrentToForgetWeights.Map()));
    }

    if (m_QuantizedLstmParameters.m_RecurrentToCellWeights != nullptr)
    {
        constTensors.emplace_back(ConstTensor(managedRecurrentToCellWeights.GetTensorInfo(),
                                              managedRecurrentToCellWeights.Map()));
    }

    if (m_QuantizedLstmParameters.m_RecurrentToOutputWeights != nullptr)
    {
        constTensors.emplace_back(ConstTensor(managedRecurrentToOutputWeights.GetTensorInfo(),
                                              managedRecurrentToOutputWeights.Map()));
    }

    // Bias tensors
    if (m_QuantizedLstmParameters.m_InputGateBias != nullptr)
    {
        constTensors.emplace_back(ConstTensor(managedInputGateBias.GetTensorInfo(),
                                              managedInputGateBias.Map()));
    }

    if (m_QuantizedLstmParameters.m_ForgetGateBias != nullptr)
    {
        constTensors.emplace_back(ConstTensor(managedForgetGateBias.GetTensorInfo(),
                                              managedForgetGateBias.Map()));
    }

    if (m_QuantizedLstmParameters.m_CellBias != nullptr)
    {
        constTensors.emplace_back(ConstTensor(managedCellBias.GetTensorInfo(),
                                              managedCellBias.Map()));
    }

    if (m_QuantizedLstmParameters.m_OutputGateBias != nullptr)
    {
        constTensors.emplace_back(ConstTensor(managedOutputGateBias.GetTensorInfo(),
                                              managedOutputGateBias.Map()));
    }

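    // QuantizedLstm has no layer parameters, so an empty BaseDescriptor is passed.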
    strategy.ExecuteStrategy(this, BaseDescriptor(), constTensors, GetName());
}

} // namespace armnn