xref: /aosp_15_r20/external/armnn/delegate/classic/src/Lstm.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include <ClassicDelegateUtils.hpp>
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/LstmParams.hpp>
11*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Tensor.hpp>
12*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/IgnoreUnused.hpp>
13*89c4ff92SAndroid Build Coastguard Worker 
14*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/builtin_ops.h>
15*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/c/builtin_op_data.h>
16*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/c/common.h>
17*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/minimal_logging.h>
18*89c4ff92SAndroid Build Coastguard Worker 
19*89c4ff92SAndroid Build Coastguard Worker namespace armnnDelegate
20*89c4ff92SAndroid Build Coastguard Worker {
21*89c4ff92SAndroid Build Coastguard Worker 
VisitLstmOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t operatorCode)22*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus VisitLstmOperator(DelegateData& delegateData,
23*89c4ff92SAndroid Build Coastguard Worker                                TfLiteContext* tfLiteContext,
24*89c4ff92SAndroid Build Coastguard Worker                                TfLiteNode* tfLiteNode,
25*89c4ff92SAndroid Build Coastguard Worker                                int nodeIndex,
26*89c4ff92SAndroid Build Coastguard Worker                                int32_t operatorCode)
27*89c4ff92SAndroid Build Coastguard Worker {
28*89c4ff92SAndroid Build Coastguard Worker     auto numInputs = tfLiteNode->inputs->size;
29*89c4ff92SAndroid Build Coastguard Worker     if (numInputs < 2)
30*89c4ff92SAndroid Build Coastguard Worker     {
31*89c4ff92SAndroid Build Coastguard Worker         TF_LITE_MAYBE_KERNEL_LOG(
32*89c4ff92SAndroid Build Coastguard Worker                 tfLiteContext, "TfLiteArmnnDelegate: Minimum number of inputs (%d != %d) in node #%d",
33*89c4ff92SAndroid Build Coastguard Worker                 2, numInputs, nodeIndex);
34*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
35*89c4ff92SAndroid Build Coastguard Worker     }
36*89c4ff92SAndroid Build Coastguard Worker 
37*89c4ff92SAndroid Build Coastguard Worker     const auto nodeParams = reinterpret_cast<TfLiteLSTMParams*>(tfLiteNode->builtin_data);
38*89c4ff92SAndroid Build Coastguard Worker     const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
39*89c4ff92SAndroid Build Coastguard Worker 
40*89c4ff92SAndroid Build Coastguard Worker     const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
41*89c4ff92SAndroid Build Coastguard Worker     if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
42*89c4ff92SAndroid Build Coastguard Worker     {
43*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
44*89c4ff92SAndroid Build Coastguard Worker     }
45*89c4ff92SAndroid Build Coastguard Worker 
46*89c4ff92SAndroid Build Coastguard Worker     const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
47*89c4ff92SAndroid Build Coastguard Worker     if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
48*89c4ff92SAndroid Build Coastguard Worker     {
49*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
50*89c4ff92SAndroid Build Coastguard Worker     }
51*89c4ff92SAndroid Build Coastguard Worker 
52*89c4ff92SAndroid Build Coastguard Worker     // Set the params structure for the AddLstmLayer call
53*89c4ff92SAndroid Build Coastguard Worker     armnn::LstmInputParams params;
54*89c4ff92SAndroid Build Coastguard Worker 
55*89c4ff92SAndroid Build Coastguard Worker     if (IsOptionalOperandPresent(tfLiteNode, 1))
56*89c4ff92SAndroid Build Coastguard Worker     {
57*89c4ff92SAndroid Build Coastguard Worker         params.m_InputToInputWeights = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 1);
58*89c4ff92SAndroid Build Coastguard Worker     }
59*89c4ff92SAndroid Build Coastguard Worker 
60*89c4ff92SAndroid Build Coastguard Worker     params.m_InputToForgetWeights = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 2);
61*89c4ff92SAndroid Build Coastguard Worker     params.m_InputToCellWeights = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 3);
62*89c4ff92SAndroid Build Coastguard Worker     params.m_InputToOutputWeights = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 4);
63*89c4ff92SAndroid Build Coastguard Worker 
64*89c4ff92SAndroid Build Coastguard Worker     // Recurrent weight tensors of size {n_cell, n_output}
65*89c4ff92SAndroid Build Coastguard Worker     if (IsOptionalOperandPresent(tfLiteNode, 5))
66*89c4ff92SAndroid Build Coastguard Worker     {
67*89c4ff92SAndroid Build Coastguard Worker         params.m_RecurrentToInputWeights = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 5);
68*89c4ff92SAndroid Build Coastguard Worker     }
69*89c4ff92SAndroid Build Coastguard Worker 
70*89c4ff92SAndroid Build Coastguard Worker     params.m_RecurrentToForgetWeights = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 6);
71*89c4ff92SAndroid Build Coastguard Worker     params.m_RecurrentToCellWeights = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 7);
72*89c4ff92SAndroid Build Coastguard Worker     params.m_RecurrentToOutputWeights = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 8);
73*89c4ff92SAndroid Build Coastguard Worker 
74*89c4ff92SAndroid Build Coastguard Worker     // Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
75*89c4ff92SAndroid Build Coastguard Worker     if (IsOptionalOperandPresent(tfLiteNode, 9))
76*89c4ff92SAndroid Build Coastguard Worker     {
77*89c4ff92SAndroid Build Coastguard Worker         params.m_CellToInputWeights = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 9);
78*89c4ff92SAndroid Build Coastguard Worker     }
79*89c4ff92SAndroid Build Coastguard Worker 
80*89c4ff92SAndroid Build Coastguard Worker     if (IsOptionalOperandPresent(tfLiteNode, 10))
81*89c4ff92SAndroid Build Coastguard Worker     {
82*89c4ff92SAndroid Build Coastguard Worker         params.m_CellToForgetWeights = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 10);
83*89c4ff92SAndroid Build Coastguard Worker     }
84*89c4ff92SAndroid Build Coastguard Worker 
85*89c4ff92SAndroid Build Coastguard Worker     if (IsOptionalOperandPresent(tfLiteNode, 11))
86*89c4ff92SAndroid Build Coastguard Worker     {
87*89c4ff92SAndroid Build Coastguard Worker         params.m_CellToOutputWeights = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 11);
88*89c4ff92SAndroid Build Coastguard Worker     }
89*89c4ff92SAndroid Build Coastguard Worker 
90*89c4ff92SAndroid Build Coastguard Worker     // Gates bias tensors of size {n_cell}
91*89c4ff92SAndroid Build Coastguard Worker     if (IsOptionalOperandPresent(tfLiteNode, 12))
92*89c4ff92SAndroid Build Coastguard Worker     {
93*89c4ff92SAndroid Build Coastguard Worker         params.m_InputGateBias = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 12);
94*89c4ff92SAndroid Build Coastguard Worker     }
95*89c4ff92SAndroid Build Coastguard Worker 
96*89c4ff92SAndroid Build Coastguard Worker     params.m_ForgetGateBias = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 13);
97*89c4ff92SAndroid Build Coastguard Worker     params.m_CellBias = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 14);
98*89c4ff92SAndroid Build Coastguard Worker     params.m_OutputGateBias = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 15);
99*89c4ff92SAndroid Build Coastguard Worker 
100*89c4ff92SAndroid Build Coastguard Worker     // Projection weight tensor of size {n_output, n_cell}
101*89c4ff92SAndroid Build Coastguard Worker     if (IsOptionalOperandPresent(tfLiteNode, 16))
102*89c4ff92SAndroid Build Coastguard Worker     {
103*89c4ff92SAndroid Build Coastguard Worker         params.m_ProjectionWeights = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 16);
104*89c4ff92SAndroid Build Coastguard Worker     }
105*89c4ff92SAndroid Build Coastguard Worker     // Projection bias tensor of size {n_output}
106*89c4ff92SAndroid Build Coastguard Worker     if (IsOptionalOperandPresent(tfLiteNode, 17))
107*89c4ff92SAndroid Build Coastguard Worker     {
108*89c4ff92SAndroid Build Coastguard Worker         params.m_ProjectionBias = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 17);
109*89c4ff92SAndroid Build Coastguard Worker     }
110*89c4ff92SAndroid Build Coastguard Worker 
111*89c4ff92SAndroid Build Coastguard Worker     // These state tensors are defined as variable tensors, and will be modified by this op.
112*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo outputStateInInfo = GetTensorInfoForTfLiteTensor(tfLiteTensors[tfLiteNode->inputs->data[18]]);
113*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo cellStateInInfo = GetTensorInfoForTfLiteTensor(tfLiteTensors[tfLiteNode->inputs->data[19]]);
114*89c4ff92SAndroid Build Coastguard Worker 
115*89c4ff92SAndroid Build Coastguard Worker     // Layer norm coefficient tensors of size {n_cell}, representing a diagonal matrix.
116*89c4ff92SAndroid Build Coastguard Worker     if (IsOptionalOperandPresent(tfLiteNode, 20))
117*89c4ff92SAndroid Build Coastguard Worker     {
118*89c4ff92SAndroid Build Coastguard Worker         params.m_InputLayerNormWeights = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 20);
119*89c4ff92SAndroid Build Coastguard Worker     }
120*89c4ff92SAndroid Build Coastguard Worker 
121*89c4ff92SAndroid Build Coastguard Worker     if (IsOptionalOperandPresent(tfLiteNode, 21))
122*89c4ff92SAndroid Build Coastguard Worker     {
123*89c4ff92SAndroid Build Coastguard Worker         params.m_ForgetLayerNormWeights = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 21);
124*89c4ff92SAndroid Build Coastguard Worker     }
125*89c4ff92SAndroid Build Coastguard Worker 
126*89c4ff92SAndroid Build Coastguard Worker     if (IsOptionalOperandPresent(tfLiteNode, 22))
127*89c4ff92SAndroid Build Coastguard Worker     {
128*89c4ff92SAndroid Build Coastguard Worker         params.m_CellLayerNormWeights = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 22);
129*89c4ff92SAndroid Build Coastguard Worker     }
130*89c4ff92SAndroid Build Coastguard Worker 
131*89c4ff92SAndroid Build Coastguard Worker     if (IsOptionalOperandPresent(tfLiteNode, 23))
132*89c4ff92SAndroid Build Coastguard Worker     {
133*89c4ff92SAndroid Build Coastguard Worker         params.m_OutputLayerNormWeights = GetConstTensorForTfLiteTensor(tfLiteTensors, tfLiteNode, 23);
134*89c4ff92SAndroid Build Coastguard Worker     }
135*89c4ff92SAndroid Build Coastguard Worker 
136*89c4ff92SAndroid Build Coastguard Worker     // set the layer descriptor
137*89c4ff92SAndroid Build Coastguard Worker     armnn::LstmDescriptor desc;
138*89c4ff92SAndroid Build Coastguard Worker     desc.m_ActivationFunc    = NonNegative(nodeParams->activation, nodeIndex);
139*89c4ff92SAndroid Build Coastguard Worker     desc.m_ClippingThresCell = nodeParams->cell_clip;
140*89c4ff92SAndroid Build Coastguard Worker     desc.m_ClippingThresProj = nodeParams->proj_clip;
141*89c4ff92SAndroid Build Coastguard Worker     desc.m_CifgEnabled       = (params.m_InputToInputWeights == nullptr
142*89c4ff92SAndroid Build Coastguard Worker                                 || params.m_RecurrentToInputWeights == nullptr
143*89c4ff92SAndroid Build Coastguard Worker                                 || params.m_InputGateBias == nullptr);
144*89c4ff92SAndroid Build Coastguard Worker     desc.m_PeepholeEnabled   = (params.m_CellToForgetWeights != nullptr || params.m_CellToOutputWeights != nullptr);
145*89c4ff92SAndroid Build Coastguard Worker     desc.m_ProjectionEnabled = (params.m_ProjectionWeights != nullptr);
146*89c4ff92SAndroid Build Coastguard Worker     desc.m_LayerNormEnabled  = (params.m_InputLayerNormWeights != nullptr
147*89c4ff92SAndroid Build Coastguard Worker                                 || params.m_ForgetLayerNormWeights != nullptr
148*89c4ff92SAndroid Build Coastguard Worker                                 || params.m_CellLayerNormWeights != nullptr
149*89c4ff92SAndroid Build Coastguard Worker                                 || params.m_OutputLayerNormWeights != nullptr);
150*89c4ff92SAndroid Build Coastguard Worker 
151*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
152*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
153*89c4ff92SAndroid Build Coastguard Worker 
154*89c4ff92SAndroid Build Coastguard Worker     unsigned int batchSize  = inputTensorInfo.GetShape()[0];
155*89c4ff92SAndroid Build Coastguard Worker     unsigned int outputSize = outputTensorInfo.GetShape()[1];
156*89c4ff92SAndroid Build Coastguard Worker     unsigned int numUnits   = cellStateInInfo.GetShape()[1];
157*89c4ff92SAndroid Build Coastguard Worker 
158*89c4ff92SAndroid Build Coastguard Worker     armnn::DataType dataType = inputTensorInfo.GetDataType();
159*89c4ff92SAndroid Build Coastguard Worker     float qScale = inputTensorInfo.GetQuantizationScale();
160*89c4ff92SAndroid Build Coastguard Worker     float qOffset = inputTensorInfo.GetQuantizationOffset();
161*89c4ff92SAndroid Build Coastguard Worker 
162*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo scratchBufferTensorInfo({batchSize, numUnits * 3}, dataType, qScale, qOffset);
163*89c4ff92SAndroid Build Coastguard Worker     if (!desc.m_CifgEnabled)
164*89c4ff92SAndroid Build Coastguard Worker     {
165*89c4ff92SAndroid Build Coastguard Worker         scratchBufferTensorInfo = armnn::TensorInfo({batchSize, numUnits * 4}, dataType, qScale, qOffset);
166*89c4ff92SAndroid Build Coastguard Worker     }
167*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo cellStateOutTensorInfo({batchSize, numUnits}, dataType, qScale, qOffset);
168*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo outputStateOutTensorInfo({batchSize, outputSize}, dataType, qScale, qOffset);
169*89c4ff92SAndroid Build Coastguard Worker 
170*89c4ff92SAndroid Build Coastguard Worker     armnn::LstmInputParamsInfo paramsInfo;
171*89c4ff92SAndroid Build Coastguard Worker     paramsInfo.m_InputToForgetWeights     = &(params.m_InputToForgetWeights->GetInfo());
172*89c4ff92SAndroid Build Coastguard Worker     paramsInfo.m_InputToCellWeights       = &(params.m_InputToCellWeights->GetInfo());
173*89c4ff92SAndroid Build Coastguard Worker     paramsInfo.m_InputToOutputWeights     = &(params.m_InputToOutputWeights->GetInfo());
174*89c4ff92SAndroid Build Coastguard Worker     paramsInfo.m_RecurrentToForgetWeights = &(params.m_RecurrentToForgetWeights->GetInfo());
175*89c4ff92SAndroid Build Coastguard Worker     paramsInfo.m_RecurrentToCellWeights   = &(params.m_RecurrentToCellWeights->GetInfo());
176*89c4ff92SAndroid Build Coastguard Worker     paramsInfo.m_RecurrentToOutputWeights = &(params.m_RecurrentToOutputWeights->GetInfo());
177*89c4ff92SAndroid Build Coastguard Worker     paramsInfo.m_ForgetGateBias           = &(params.m_ForgetGateBias->GetInfo());
178*89c4ff92SAndroid Build Coastguard Worker     paramsInfo.m_CellBias                 = &(params.m_CellBias->GetInfo());
179*89c4ff92SAndroid Build Coastguard Worker     paramsInfo.m_OutputGateBias           = &(params.m_OutputGateBias->GetInfo());
180*89c4ff92SAndroid Build Coastguard Worker 
181*89c4ff92SAndroid Build Coastguard Worker     if (!desc.m_CifgEnabled)
182*89c4ff92SAndroid Build Coastguard Worker     {
183*89c4ff92SAndroid Build Coastguard Worker         paramsInfo.m_InputToInputWeights = &(params.m_InputToInputWeights->GetInfo());
184*89c4ff92SAndroid Build Coastguard Worker         paramsInfo.m_RecurrentToInputWeights = &(params.m_RecurrentToInputWeights->GetInfo());
185*89c4ff92SAndroid Build Coastguard Worker         if (params.m_CellToInputWeights != nullptr)
186*89c4ff92SAndroid Build Coastguard Worker         {
187*89c4ff92SAndroid Build Coastguard Worker             paramsInfo.m_CellToInputWeights = &(params.m_CellToInputWeights->GetInfo());
188*89c4ff92SAndroid Build Coastguard Worker         }
189*89c4ff92SAndroid Build Coastguard Worker         paramsInfo.m_InputGateBias = &(params.m_InputGateBias->GetInfo());
190*89c4ff92SAndroid Build Coastguard Worker     }
191*89c4ff92SAndroid Build Coastguard Worker 
192*89c4ff92SAndroid Build Coastguard Worker     if (desc.m_ProjectionEnabled)
193*89c4ff92SAndroid Build Coastguard Worker     {
194*89c4ff92SAndroid Build Coastguard Worker         paramsInfo.m_ProjectionWeights = &(params.m_ProjectionWeights->GetInfo());
195*89c4ff92SAndroid Build Coastguard Worker         if (params.m_ProjectionBias != nullptr)
196*89c4ff92SAndroid Build Coastguard Worker         {
197*89c4ff92SAndroid Build Coastguard Worker             paramsInfo.m_ProjectionBias = &(params.m_ProjectionBias->GetInfo());
198*89c4ff92SAndroid Build Coastguard Worker         }
199*89c4ff92SAndroid Build Coastguard Worker     }
200*89c4ff92SAndroid Build Coastguard Worker 
201*89c4ff92SAndroid Build Coastguard Worker     if (desc.m_PeepholeEnabled)
202*89c4ff92SAndroid Build Coastguard Worker     {
203*89c4ff92SAndroid Build Coastguard Worker         paramsInfo.m_CellToForgetWeights = &(params.m_CellToForgetWeights->GetInfo());
204*89c4ff92SAndroid Build Coastguard Worker         paramsInfo.m_CellToOutputWeights = &(params.m_CellToOutputWeights->GetInfo());
205*89c4ff92SAndroid Build Coastguard Worker     }
206*89c4ff92SAndroid Build Coastguard Worker 
207*89c4ff92SAndroid Build Coastguard Worker     if (desc.m_LayerNormEnabled)
208*89c4ff92SAndroid Build Coastguard Worker     {
209*89c4ff92SAndroid Build Coastguard Worker         if(!desc.m_CifgEnabled)
210*89c4ff92SAndroid Build Coastguard Worker         {
211*89c4ff92SAndroid Build Coastguard Worker             paramsInfo.m_InputLayerNormWeights = &(params.m_InputLayerNormWeights->GetInfo());
212*89c4ff92SAndroid Build Coastguard Worker         }
213*89c4ff92SAndroid Build Coastguard Worker         paramsInfo.m_ForgetLayerNormWeights = &(params.m_ForgetLayerNormWeights->GetInfo());
214*89c4ff92SAndroid Build Coastguard Worker         paramsInfo.m_CellLayerNormWeights = &(params.m_CellLayerNormWeights->GetInfo());
215*89c4ff92SAndroid Build Coastguard Worker         paramsInfo.m_OutputLayerNormWeights = &(params.m_OutputLayerNormWeights->GetInfo());
216*89c4ff92SAndroid Build Coastguard Worker     }
217*89c4ff92SAndroid Build Coastguard Worker 
218*89c4ff92SAndroid Build Coastguard Worker     bool isSupported = false;
219*89c4ff92SAndroid Build Coastguard Worker     armnn::BackendId setBackend;
220*89c4ff92SAndroid Build Coastguard Worker     auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported)
221*89c4ff92SAndroid Build Coastguard Worker     {
222*89c4ff92SAndroid Build Coastguard Worker         FORWARD_LAYER_SUPPORT_FUNC("LSTM",
223*89c4ff92SAndroid Build Coastguard Worker                                    tfLiteContext,
224*89c4ff92SAndroid Build Coastguard Worker                                    IsLstmSupported,
225*89c4ff92SAndroid Build Coastguard Worker                                    delegateData.m_Backends,
226*89c4ff92SAndroid Build Coastguard Worker                                    isSupported,
227*89c4ff92SAndroid Build Coastguard Worker                                    setBackend,
228*89c4ff92SAndroid Build Coastguard Worker                                    inputTensorInfo,
229*89c4ff92SAndroid Build Coastguard Worker                                    outputStateInInfo,
230*89c4ff92SAndroid Build Coastguard Worker                                    cellStateInInfo,
231*89c4ff92SAndroid Build Coastguard Worker                                    scratchBufferTensorInfo,
232*89c4ff92SAndroid Build Coastguard Worker                                    outputStateOutTensorInfo,
233*89c4ff92SAndroid Build Coastguard Worker                                    cellStateOutTensorInfo,
234*89c4ff92SAndroid Build Coastguard Worker                                    outputInfo,
235*89c4ff92SAndroid Build Coastguard Worker                                    desc,
236*89c4ff92SAndroid Build Coastguard Worker                                    paramsInfo);
237*89c4ff92SAndroid Build Coastguard Worker     };
238*89c4ff92SAndroid Build Coastguard Worker 
239*89c4ff92SAndroid Build Coastguard Worker     if (!delegateData.m_Network)
240*89c4ff92SAndroid Build Coastguard Worker     {
241*89c4ff92SAndroid Build Coastguard Worker         validateFunc(outputTensorInfo, isSupported);
242*89c4ff92SAndroid Build Coastguard Worker         return isSupported ? kTfLiteOk : kTfLiteError;
243*89c4ff92SAndroid Build Coastguard Worker     }
244*89c4ff92SAndroid Build Coastguard Worker 
245*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* layer = delegateData.m_Network->AddLstmLayer(desc, params);
246*89c4ff92SAndroid Build Coastguard Worker     layer->SetBackendId(setBackend);
247*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
248*89c4ff92SAndroid Build Coastguard Worker 
249*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(0).SetTensorInfo(scratchBufferTensorInfo);
250*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(1).SetTensorInfo(outputStateOutTensorInfo);
251*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(2).SetTensorInfo(cellStateOutTensorInfo);
252*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(3).SetTensorInfo(outputTensorInfo);
253*89c4ff92SAndroid Build Coastguard Worker 
254*89c4ff92SAndroid Build Coastguard Worker     // Connect the inputs
255*89c4ff92SAndroid Build Coastguard Worker     // input_layer
256*89c4ff92SAndroid Build Coastguard Worker     delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[0]]->Connect(layer->GetInputSlot(0));
257*89c4ff92SAndroid Build Coastguard Worker     // cellStateIn
258*89c4ff92SAndroid Build Coastguard Worker     delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[18]]->Connect(layer->GetInputSlot(1));
259*89c4ff92SAndroid Build Coastguard Worker     //outputStateIn
260*89c4ff92SAndroid Build Coastguard Worker     delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[19]]->Connect(layer->GetInputSlot(2));
261*89c4ff92SAndroid Build Coastguard Worker 
262*89c4ff92SAndroid Build Coastguard Worker     // In the test_model there is only 1 Output
263*89c4ff92SAndroid Build Coastguard Worker     armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(1);
264*89c4ff92SAndroid Build Coastguard Worker     delegateData.m_OutputSlotForNode[static_cast<unsigned long>(tfLiteNode->outputs->data[0])] = &outputSlot;
265*89c4ff92SAndroid Build Coastguard Worker     return kTfLiteOk;
266*89c4ff92SAndroid Build Coastguard Worker }
267*89c4ff92SAndroid Build Coastguard Worker 
268*89c4ff92SAndroid Build Coastguard Worker } // namespace armnnDelegate