xref: /aosp_15_r20/external/armnn/delegate/classic/src/Lstm.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <ClassicDelegateUtils.hpp>
9 
10 #include <armnn/LstmParams.hpp>
11 #include <armnn/Tensor.hpp>
12 #include <armnn/utility/IgnoreUnused.hpp>
13 
14 #include <tensorflow/lite/builtin_ops.h>
15 #include <tensorflow/lite/c/builtin_op_data.h>
16 #include <tensorflow/lite/c/common.h>
17 #include <tensorflow/lite/minimal_logging.h>
18 
19 namespace armnnDelegate
20 {
21 
VisitLstmOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t operatorCode)22 TfLiteStatus VisitLstmOperator(DelegateData& delegateData,
23                                TfLiteContext* tfLiteContext,
24                                TfLiteNode* tfLiteNode,
25                                int nodeIndex,
26                                int32_t operatorCode)
27 {
28     auto numInputs = tfLiteNode->inputs->size;
29     if (numInputs < 2)
30     {
31         TF_LITE_MAYBE_KERNEL_LOG(
32                 tfLiteContext, "TfLiteArmnnDelegate: Minimum number of inputs (%d != %d) in node #%d",
33                 2, numInputs, nodeIndex);
34         return kTfLiteError;
35     }
36 
37     const auto nodeParams = reinterpret_cast<TfLiteLSTMParams*>(tfLiteNode->builtin_data);
38     const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
39 
40     const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
41     if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
42     {
43         return kTfLiteError;
44     }
45 
46     const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
47     if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
48     {
49         return kTfLiteError;
50     }
51 
52     // Set the params structure for the AddLstmLayer call
53     armnn::LstmInputParams params;
54 
55     if (IsOptionalOperandPresent(tfLiteNode, 1))
56     {
57         params.m_InputToInputWeights = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 1);
58     }
59 
60     params.m_InputToForgetWeights = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 2);
61     params.m_InputToCellWeights = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 3);
62     params.m_InputToOutputWeights = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 4);
63 
64     // Recurrent weight tensors of size {n_cell, n_output}
65     if (IsOptionalOperandPresent(tfLiteNode, 5))
66     {
67         params.m_RecurrentToInputWeights = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 5);
68     }
69 
70     params.m_RecurrentToForgetWeights = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 6);
71     params.m_RecurrentToCellWeights = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 7);
72     params.m_RecurrentToOutputWeights = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 8);
73 
74     // Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
75     if (IsOptionalOperandPresent(tfLiteNode, 9))
76     {
77         params.m_CellToInputWeights = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 9);
78     }
79 
80     if (IsOptionalOperandPresent(tfLiteNode, 10))
81     {
82         params.m_CellToForgetWeights = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 10);
83     }
84 
85     if (IsOptionalOperandPresent(tfLiteNode, 11))
86     {
87         params.m_CellToOutputWeights = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 11);
88     }
89 
90     // Gates bias tensors of size {n_cell}
91     if (IsOptionalOperandPresent(tfLiteNode, 12))
92     {
93         params.m_InputGateBias = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 12);
94     }
95 
96     params.m_ForgetGateBias = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 13);
97     params.m_CellBias = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 14);
98     params.m_OutputGateBias = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 15);
99 
100     // Projection weight tensor of size {n_output, n_cell}
101     if (IsOptionalOperandPresent(tfLiteNode, 16))
102     {
103         params.m_ProjectionWeights = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 16);
104     }
105     // Projection bias tensor of size {n_output}
106     if (IsOptionalOperandPresent(tfLiteNode, 17))
107     {
108         params.m_ProjectionBias = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 17);
109     }
110 
111     // These state tensors are defined as variable tensors, and will be modified by this op.
112     armnn::TensorInfo outputStateInInfo = GetTensorInfoForTfLiteTensor(tfLiteTensors[tfLiteNode->inputs->data[18]]);
113     armnn::TensorInfo cellStateInInfo = GetTensorInfoForTfLiteTensor(tfLiteTensors[tfLiteNode->inputs->data[19]]);
114 
115     // Layer norm coefficient tensors of size {n_cell}, representing a diagonal matrix.
116     if (IsOptionalOperandPresent(tfLiteNode, 20))
117     {
118         params.m_InputLayerNormWeights = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 20);
119     }
120 
121     if (IsOptionalOperandPresent(tfLiteNode, 21))
122     {
123         params.m_ForgetLayerNormWeights = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 21);
124     }
125 
126     if (IsOptionalOperandPresent(tfLiteNode, 22))
127     {
128         params.m_CellLayerNormWeights = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 22);
129     }
130 
131     if (IsOptionalOperandPresent(tfLiteNode, 23))
132     {
133         params.m_OutputLayerNormWeights = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 23);
134     }
135 
136     // set the layer descriptor
137     armnn::LstmDescriptor desc;
138     desc.m_ActivationFunc    = NonNegative(nodeParams->activation, nodeIndex);
139     desc.m_ClippingThresCell = nodeParams->cell_clip;
140     desc.m_ClippingThresProj = nodeParams->proj_clip;
141     desc.m_CifgEnabled       = (params.m_InputToInputWeights == nullptr
142                                 || params.m_RecurrentToInputWeights == nullptr
143                                 || params.m_InputGateBias == nullptr);
144     desc.m_PeepholeEnabled   = (params.m_CellToForgetWeights != nullptr || params.m_CellToOutputWeights != nullptr);
145     desc.m_ProjectionEnabled = (params.m_ProjectionWeights != nullptr);
146     desc.m_LayerNormEnabled  = (params.m_InputLayerNormWeights != nullptr
147                                 || params.m_ForgetLayerNormWeights != nullptr
148                                 || params.m_CellLayerNormWeights != nullptr
149                                 || params.m_OutputLayerNormWeights != nullptr);
150 
151     const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
152     const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
153 
154     unsigned int batchSize  = inputTensorInfo.GetShape()[0];
155     unsigned int outputSize = outputTensorInfo.GetShape()[1];
156     unsigned int numUnits   = cellStateInInfo.GetShape()[1];
157 
158     armnn::DataType dataType = inputTensorInfo.GetDataType();
159     float qScale = inputTensorInfo.GetQuantizationScale();
160     float qOffset = inputTensorInfo.GetQuantizationOffset();
161 
162     armnn::TensorInfo scratchBufferTensorInfo({batchSize, numUnits * 3}, dataType, qScale, qOffset);
163     if (!desc.m_CifgEnabled)
164     {
165         scratchBufferTensorInfo = armnn::TensorInfo({batchSize, numUnits * 4}, dataType, qScale, qOffset);
166     }
167     armnn::TensorInfo cellStateOutTensorInfo({batchSize, numUnits}, dataType, qScale, qOffset);
168     armnn::TensorInfo outputStateOutTensorInfo({batchSize, outputSize}, dataType, qScale, qOffset);
169 
170     armnn::LstmInputParamsInfo paramsInfo;
171     paramsInfo.m_InputToForgetWeights     = &(params.m_InputToForgetWeights->GetInfo());
172     paramsInfo.m_InputToCellWeights       = &(params.m_InputToCellWeights->GetInfo());
173     paramsInfo.m_InputToOutputWeights     = &(params.m_InputToOutputWeights->GetInfo());
174     paramsInfo.m_RecurrentToForgetWeights = &(params.m_RecurrentToForgetWeights->GetInfo());
175     paramsInfo.m_RecurrentToCellWeights   = &(params.m_RecurrentToCellWeights->GetInfo());
176     paramsInfo.m_RecurrentToOutputWeights = &(params.m_RecurrentToOutputWeights->GetInfo());
177     paramsInfo.m_ForgetGateBias           = &(params.m_ForgetGateBias->GetInfo());
178     paramsInfo.m_CellBias                 = &(params.m_CellBias->GetInfo());
179     paramsInfo.m_OutputGateBias           = &(params.m_OutputGateBias->GetInfo());
180 
181     if (!desc.m_CifgEnabled)
182     {
183         paramsInfo.m_InputToInputWeights = &(params.m_InputToInputWeights->GetInfo());
184         paramsInfo.m_RecurrentToInputWeights = &(params.m_RecurrentToInputWeights->GetInfo());
185         if (params.m_CellToInputWeights != nullptr)
186         {
187             paramsInfo.m_CellToInputWeights = &(params.m_CellToInputWeights->GetInfo());
188         }
189         paramsInfo.m_InputGateBias = &(params.m_InputGateBias->GetInfo());
190     }
191 
192     if (desc.m_ProjectionEnabled)
193     {
194         paramsInfo.m_ProjectionWeights = &(params.m_ProjectionWeights->GetInfo());
195         if (params.m_ProjectionBias != nullptr)
196         {
197             paramsInfo.m_ProjectionBias = &(params.m_ProjectionBias->GetInfo());
198         }
199     }
200 
201     if (desc.m_PeepholeEnabled)
202     {
203         paramsInfo.m_CellToForgetWeights = &(params.m_CellToForgetWeights->GetInfo());
204         paramsInfo.m_CellToOutputWeights = &(params.m_CellToOutputWeights->GetInfo());
205     }
206 
207     if (desc.m_LayerNormEnabled)
208     {
209         if(!desc.m_CifgEnabled)
210         {
211             paramsInfo.m_InputLayerNormWeights = &(params.m_InputLayerNormWeights->GetInfo());
212         }
213         paramsInfo.m_ForgetLayerNormWeights = &(params.m_ForgetLayerNormWeights->GetInfo());
214         paramsInfo.m_CellLayerNormWeights = &(params.m_CellLayerNormWeights->GetInfo());
215         paramsInfo.m_OutputLayerNormWeights = &(params.m_OutputLayerNormWeights->GetInfo());
216     }
217 
218     bool isSupported = false;
219     armnn::BackendId setBackend;
220     auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported)
221     {
222         FORWARD_LAYER_SUPPORT_FUNC("LSTM",
223                                    tfLiteContext,
224                                    IsLstmSupported,
225                                    delegateData.m_Backends,
226                                    isSupported,
227                                    setBackend,
228                                    inputTensorInfo,
229                                    outputStateInInfo,
230                                    cellStateInInfo,
231                                    scratchBufferTensorInfo,
232                                    outputStateOutTensorInfo,
233                                    cellStateOutTensorInfo,
234                                    outputInfo,
235                                    desc,
236                                    paramsInfo);
237     };
238 
239     if (!delegateData.m_Network)
240     {
241         validateFunc(outputTensorInfo, isSupported);
242         return isSupported ? kTfLiteOk : kTfLiteError;
243     }
244 
245     armnn::IConnectableLayer* layer = delegateData.m_Network->AddLstmLayer(desc, params);
246     layer->SetBackendId(setBackend);
247     ARMNN_ASSERT(layer != nullptr);
248 
249     layer->GetOutputSlot(0).SetTensorInfo(scratchBufferTensorInfo);
250     layer->GetOutputSlot(1).SetTensorInfo(outputStateOutTensorInfo);
251     layer->GetOutputSlot(2).SetTensorInfo(cellStateOutTensorInfo);
252     layer->GetOutputSlot(3).SetTensorInfo(outputTensorInfo);
253 
254     // Connect the inputs
255     // input_layer
256     delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[0]]->Connect(layer->GetInputSlot(0));
257     // cellStateIn
258     delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[18]]->Connect(layer->GetInputSlot(1));
259     //outputStateIn
260     delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[19]]->Connect(layer->GetInputSlot(2));
261 
262     // In the test_model there is only 1 Output
263     armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(1);
264     delegateData.m_OutputSlotForNode[static_cast<unsigned long>(tfLiteNode->outputs->data[0])] = &outputSlot;
265     return kTfLiteOk;
266 }
267 
268 } // namespace armnnDelegate