xref: /aosp_15_r20/external/armnn/delegate/classic/src/ElementwiseBinary.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include <ClassicDelegateUtils.hpp>
9*89c4ff92SAndroid Build Coastguard Worker #include "MultiLayerFacade.hpp"
10*89c4ff92SAndroid Build Coastguard Worker #include "SharedFunctions.hpp"
11*89c4ff92SAndroid Build Coastguard Worker 
12*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/builtin_ops.h>
13*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/c/builtin_op_data.h>
14*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/c/common.h>
15*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/minimal_logging.h>
16*89c4ff92SAndroid Build Coastguard Worker #include "tensorflow/lite/delegates/utils.h"
17*89c4ff92SAndroid Build Coastguard Worker 
18*89c4ff92SAndroid Build Coastguard Worker namespace armnnDelegate
19*89c4ff92SAndroid Build Coastguard Worker {
20*89c4ff92SAndroid Build Coastguard Worker 
ValidateAddOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,const armnn::TensorInfo & inputInfo1,const armnn::TensorInfo & inputInfo2,const armnn::TensorInfo & outputInfo)21*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus ValidateAddOperator(DelegateData& delegateData,
22*89c4ff92SAndroid Build Coastguard Worker                                  TfLiteContext* tfLiteContext,
23*89c4ff92SAndroid Build Coastguard Worker                                  const armnn::TensorInfo& inputInfo1,
24*89c4ff92SAndroid Build Coastguard Worker                                  const armnn::TensorInfo& inputInfo2,
25*89c4ff92SAndroid Build Coastguard Worker                                  const armnn::TensorInfo& outputInfo)
26*89c4ff92SAndroid Build Coastguard Worker {
27*89c4ff92SAndroid Build Coastguard Worker     bool isSupported = false;
28*89c4ff92SAndroid Build Coastguard Worker     auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
29*89c4ff92SAndroid Build Coastguard Worker     {
30*89c4ff92SAndroid Build Coastguard Worker         std::vector<armnn::TensorInfo> infos { inputInfo1, inputInfo2, outputInfo };
31*89c4ff92SAndroid Build Coastguard Worker         FORWARD_LAYER_SUPPORT_FUNC("ADD",
32*89c4ff92SAndroid Build Coastguard Worker                                    tfLiteContext,
33*89c4ff92SAndroid Build Coastguard Worker                                    IsElementwiseBinarySupported,
34*89c4ff92SAndroid Build Coastguard Worker                                    delegateData.m_Backends,
35*89c4ff92SAndroid Build Coastguard Worker                                    isSupported,
36*89c4ff92SAndroid Build Coastguard Worker                                    armnn::BackendId(),
37*89c4ff92SAndroid Build Coastguard Worker                                    inputInfo1,
38*89c4ff92SAndroid Build Coastguard Worker                                    inputInfo2,
39*89c4ff92SAndroid Build Coastguard Worker                                    outputInfo,
40*89c4ff92SAndroid Build Coastguard Worker                                    armnn::BinaryOperation::Add);
41*89c4ff92SAndroid Build Coastguard Worker     };
42*89c4ff92SAndroid Build Coastguard Worker 
43*89c4ff92SAndroid Build Coastguard Worker     validateFunc(outputInfo, isSupported);
44*89c4ff92SAndroid Build Coastguard Worker     return isSupported ? kTfLiteOk : kTfLiteError;
45*89c4ff92SAndroid Build Coastguard Worker }
46*89c4ff92SAndroid Build Coastguard Worker 
47*89c4ff92SAndroid Build Coastguard Worker 
ValidateDivOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,const armnn::TensorInfo & inputInfo1,const armnn::TensorInfo & inputInfo2,const armnn::TensorInfo & outputInfo)48*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus ValidateDivOperator(DelegateData& delegateData,
49*89c4ff92SAndroid Build Coastguard Worker                                  TfLiteContext* tfLiteContext,
50*89c4ff92SAndroid Build Coastguard Worker                                  const armnn::TensorInfo& inputInfo1,
51*89c4ff92SAndroid Build Coastguard Worker                                  const armnn::TensorInfo& inputInfo2,
52*89c4ff92SAndroid Build Coastguard Worker                                  const armnn::TensorInfo& outputInfo)
53*89c4ff92SAndroid Build Coastguard Worker {
54*89c4ff92SAndroid Build Coastguard Worker     bool isSupported = false;
55*89c4ff92SAndroid Build Coastguard Worker     auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
56*89c4ff92SAndroid Build Coastguard Worker     {
57*89c4ff92SAndroid Build Coastguard Worker         FORWARD_LAYER_SUPPORT_FUNC("DIV",
58*89c4ff92SAndroid Build Coastguard Worker                                    tfLiteContext,
59*89c4ff92SAndroid Build Coastguard Worker                                    IsElementwiseBinarySupported,
60*89c4ff92SAndroid Build Coastguard Worker                                    delegateData.m_Backends,
61*89c4ff92SAndroid Build Coastguard Worker                                    isSupported,
62*89c4ff92SAndroid Build Coastguard Worker                                    armnn::BackendId(),
63*89c4ff92SAndroid Build Coastguard Worker                                    inputInfo1,
64*89c4ff92SAndroid Build Coastguard Worker                                    inputInfo2,
65*89c4ff92SAndroid Build Coastguard Worker                                    outputTensorInfo,
66*89c4ff92SAndroid Build Coastguard Worker                                    armnn::BinaryOperation::Div);
67*89c4ff92SAndroid Build Coastguard Worker     };
68*89c4ff92SAndroid Build Coastguard Worker 
69*89c4ff92SAndroid Build Coastguard Worker     validateFunc(outputInfo, isSupported);
70*89c4ff92SAndroid Build Coastguard Worker     return isSupported ? kTfLiteOk : kTfLiteError;
71*89c4ff92SAndroid Build Coastguard Worker }
72*89c4ff92SAndroid Build Coastguard Worker 
ValidateFloorDivOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,const armnn::TensorInfo & inputInfo1,const armnn::TensorInfo & inputInfo2,const armnn::TensorInfo & outputInfo)73*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus ValidateFloorDivOperator(DelegateData& delegateData,
74*89c4ff92SAndroid Build Coastguard Worker                                       TfLiteContext* tfLiteContext,
75*89c4ff92SAndroid Build Coastguard Worker                                       const armnn::TensorInfo& inputInfo1,
76*89c4ff92SAndroid Build Coastguard Worker                                       const armnn::TensorInfo& inputInfo2,
77*89c4ff92SAndroid Build Coastguard Worker                                       const armnn::TensorInfo& outputInfo)
78*89c4ff92SAndroid Build Coastguard Worker {
79*89c4ff92SAndroid Build Coastguard Worker     // need first to validate that the div operator is supported
80*89c4ff92SAndroid Build Coastguard Worker     // then that the floor operator is supported
81*89c4ff92SAndroid Build Coastguard Worker     TfLiteStatus status = ValidateDivOperator(delegateData, tfLiteContext, inputInfo1, inputInfo2, outputInfo);
82*89c4ff92SAndroid Build Coastguard Worker     if (status != kTfLiteOk)
83*89c4ff92SAndroid Build Coastguard Worker     {
84*89c4ff92SAndroid Build Coastguard Worker         return status;
85*89c4ff92SAndroid Build Coastguard Worker     }
86*89c4ff92SAndroid Build Coastguard Worker     // if the inputs and output of the div are all Signed32 we don't need to add the floor operator afterward.
87*89c4ff92SAndroid Build Coastguard Worker     if (AreAllSigned32(inputInfo1, inputInfo2, outputInfo))
88*89c4ff92SAndroid Build Coastguard Worker     {
89*89c4ff92SAndroid Build Coastguard Worker         return status;
90*89c4ff92SAndroid Build Coastguard Worker     }
91*89c4ff92SAndroid Build Coastguard Worker     // in case broadcasting is being done from one of the inputs to the div
92*89c4ff92SAndroid Build Coastguard Worker     // choose the full sized input tensor to pass to the floor validation routine
93*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo floorInputInfo = inputInfo1;
94*89c4ff92SAndroid Build Coastguard Worker     if (inputInfo1.GetNumDimensions() < inputInfo2.GetNumDimensions())
95*89c4ff92SAndroid Build Coastguard Worker     {
96*89c4ff92SAndroid Build Coastguard Worker         floorInputInfo = inputInfo2;
97*89c4ff92SAndroid Build Coastguard Worker     }
98*89c4ff92SAndroid Build Coastguard Worker     status = ValidateFloorOperator(delegateData, tfLiteContext, floorInputInfo, outputInfo);
99*89c4ff92SAndroid Build Coastguard Worker     return status;
100*89c4ff92SAndroid Build Coastguard Worker }
101*89c4ff92SAndroid Build Coastguard Worker 
ValidateMaximumOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,const armnn::TensorInfo & inputInfo1,const armnn::TensorInfo & inputInfo2,const armnn::TensorInfo & outputInfo)102*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus ValidateMaximumOperator(DelegateData& delegateData,
103*89c4ff92SAndroid Build Coastguard Worker                                      TfLiteContext* tfLiteContext,
104*89c4ff92SAndroid Build Coastguard Worker                                      const armnn::TensorInfo& inputInfo1,
105*89c4ff92SAndroid Build Coastguard Worker                                      const armnn::TensorInfo& inputInfo2,
106*89c4ff92SAndroid Build Coastguard Worker                                      const armnn::TensorInfo& outputInfo)
107*89c4ff92SAndroid Build Coastguard Worker {
108*89c4ff92SAndroid Build Coastguard Worker     bool isSupported = false;
109*89c4ff92SAndroid Build Coastguard Worker     auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
110*89c4ff92SAndroid Build Coastguard Worker     {
111*89c4ff92SAndroid Build Coastguard Worker         FORWARD_LAYER_SUPPORT_FUNC("MAXIMUM",
112*89c4ff92SAndroid Build Coastguard Worker                                    tfLiteContext,
113*89c4ff92SAndroid Build Coastguard Worker                                    IsElementwiseBinarySupported,
114*89c4ff92SAndroid Build Coastguard Worker                                    delegateData.m_Backends,
115*89c4ff92SAndroid Build Coastguard Worker                                    isSupported,
116*89c4ff92SAndroid Build Coastguard Worker                                    armnn::BackendId(),
117*89c4ff92SAndroid Build Coastguard Worker                                    inputInfo1,
118*89c4ff92SAndroid Build Coastguard Worker                                    inputInfo2,
119*89c4ff92SAndroid Build Coastguard Worker                                    outputTensorInfo,
120*89c4ff92SAndroid Build Coastguard Worker                                    armnn::BinaryOperation::Maximum);
121*89c4ff92SAndroid Build Coastguard Worker     };
122*89c4ff92SAndroid Build Coastguard Worker 
123*89c4ff92SAndroid Build Coastguard Worker     validateFunc(outputInfo, isSupported);
124*89c4ff92SAndroid Build Coastguard Worker     return isSupported ? kTfLiteOk : kTfLiteError;
125*89c4ff92SAndroid Build Coastguard Worker }
126*89c4ff92SAndroid Build Coastguard Worker 
ValidateMinimumOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,const armnn::TensorInfo & inputInfo1,const armnn::TensorInfo & inputInfo2,const armnn::TensorInfo & outputInfo)127*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus ValidateMinimumOperator(DelegateData& delegateData,
128*89c4ff92SAndroid Build Coastguard Worker                                      TfLiteContext* tfLiteContext,
129*89c4ff92SAndroid Build Coastguard Worker                                      const armnn::TensorInfo& inputInfo1,
130*89c4ff92SAndroid Build Coastguard Worker                                      const armnn::TensorInfo& inputInfo2,
131*89c4ff92SAndroid Build Coastguard Worker                                      const armnn::TensorInfo& outputInfo)
132*89c4ff92SAndroid Build Coastguard Worker {
133*89c4ff92SAndroid Build Coastguard Worker     bool isSupported = false;
134*89c4ff92SAndroid Build Coastguard Worker     auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
135*89c4ff92SAndroid Build Coastguard Worker     {
136*89c4ff92SAndroid Build Coastguard Worker         FORWARD_LAYER_SUPPORT_FUNC("MINIMUM",
137*89c4ff92SAndroid Build Coastguard Worker                                    tfLiteContext,
138*89c4ff92SAndroid Build Coastguard Worker                                    IsElementwiseBinarySupported,
139*89c4ff92SAndroid Build Coastguard Worker                                    delegateData.m_Backends,
140*89c4ff92SAndroid Build Coastguard Worker                                    isSupported,
141*89c4ff92SAndroid Build Coastguard Worker                                    armnn::BackendId(),
142*89c4ff92SAndroid Build Coastguard Worker                                    inputInfo1,
143*89c4ff92SAndroid Build Coastguard Worker                                    inputInfo2,
144*89c4ff92SAndroid Build Coastguard Worker                                    outputTensorInfo,
145*89c4ff92SAndroid Build Coastguard Worker                                    armnn::BinaryOperation::Minimum);
146*89c4ff92SAndroid Build Coastguard Worker     };
147*89c4ff92SAndroid Build Coastguard Worker 
148*89c4ff92SAndroid Build Coastguard Worker     validateFunc(outputInfo, isSupported);
149*89c4ff92SAndroid Build Coastguard Worker     return isSupported ? kTfLiteOk : kTfLiteError;
150*89c4ff92SAndroid Build Coastguard Worker }
151*89c4ff92SAndroid Build Coastguard Worker 
ValidateMulOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,const armnn::TensorInfo & inputInfo1,const armnn::TensorInfo & inputInfo2,const armnn::TensorInfo & outputInfo)152*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus ValidateMulOperator(DelegateData& delegateData,
153*89c4ff92SAndroid Build Coastguard Worker                                  TfLiteContext* tfLiteContext,
154*89c4ff92SAndroid Build Coastguard Worker                                  const armnn::TensorInfo& inputInfo1,
155*89c4ff92SAndroid Build Coastguard Worker                                  const armnn::TensorInfo& inputInfo2,
156*89c4ff92SAndroid Build Coastguard Worker                                  const armnn::TensorInfo& outputInfo)
157*89c4ff92SAndroid Build Coastguard Worker {
158*89c4ff92SAndroid Build Coastguard Worker     bool isSupported = false;
159*89c4ff92SAndroid Build Coastguard Worker     auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
160*89c4ff92SAndroid Build Coastguard Worker     {
161*89c4ff92SAndroid Build Coastguard Worker         FORWARD_LAYER_SUPPORT_FUNC("MUL",
162*89c4ff92SAndroid Build Coastguard Worker                                    tfLiteContext,
163*89c4ff92SAndroid Build Coastguard Worker                                    IsElementwiseBinarySupported,
164*89c4ff92SAndroid Build Coastguard Worker                                    delegateData.m_Backends,
165*89c4ff92SAndroid Build Coastguard Worker                                    isSupported,
166*89c4ff92SAndroid Build Coastguard Worker                                    armnn::BackendId(),
167*89c4ff92SAndroid Build Coastguard Worker                                    inputInfo1,
168*89c4ff92SAndroid Build Coastguard Worker                                    inputInfo2,
169*89c4ff92SAndroid Build Coastguard Worker                                    outputTensorInfo,
170*89c4ff92SAndroid Build Coastguard Worker                                    armnn::BinaryOperation::Mul);
171*89c4ff92SAndroid Build Coastguard Worker     };
172*89c4ff92SAndroid Build Coastguard Worker 
173*89c4ff92SAndroid Build Coastguard Worker     validateFunc(outputInfo, isSupported);
174*89c4ff92SAndroid Build Coastguard Worker     return isSupported ? kTfLiteOk : kTfLiteError;
175*89c4ff92SAndroid Build Coastguard Worker }
176*89c4ff92SAndroid Build Coastguard Worker 
ValidateSubOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,const armnn::TensorInfo & inputInfo1,const armnn::TensorInfo & inputInfo2,const armnn::TensorInfo & outputInfo)177*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus ValidateSubOperator(DelegateData& delegateData,
178*89c4ff92SAndroid Build Coastguard Worker                                  TfLiteContext* tfLiteContext,
179*89c4ff92SAndroid Build Coastguard Worker                                  const armnn::TensorInfo& inputInfo1,
180*89c4ff92SAndroid Build Coastguard Worker                                  const armnn::TensorInfo& inputInfo2,
181*89c4ff92SAndroid Build Coastguard Worker                                  const armnn::TensorInfo& outputInfo)
182*89c4ff92SAndroid Build Coastguard Worker {
183*89c4ff92SAndroid Build Coastguard Worker     bool isSupported = false;
184*89c4ff92SAndroid Build Coastguard Worker     auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
185*89c4ff92SAndroid Build Coastguard Worker     {
186*89c4ff92SAndroid Build Coastguard Worker         FORWARD_LAYER_SUPPORT_FUNC("SUB",
187*89c4ff92SAndroid Build Coastguard Worker                                    tfLiteContext,
188*89c4ff92SAndroid Build Coastguard Worker                                    IsElementwiseBinarySupported,
189*89c4ff92SAndroid Build Coastguard Worker                                    delegateData.m_Backends,
190*89c4ff92SAndroid Build Coastguard Worker                                    isSupported,
191*89c4ff92SAndroid Build Coastguard Worker                                    armnn::BackendId(),
192*89c4ff92SAndroid Build Coastguard Worker                                    inputInfo1,
193*89c4ff92SAndroid Build Coastguard Worker                                    inputInfo2,
194*89c4ff92SAndroid Build Coastguard Worker                                    outputTensorInfo,
195*89c4ff92SAndroid Build Coastguard Worker                                    armnn::BinaryOperation::Sub);
196*89c4ff92SAndroid Build Coastguard Worker     };
197*89c4ff92SAndroid Build Coastguard Worker 
198*89c4ff92SAndroid Build Coastguard Worker     validateFunc(outputInfo, isSupported);
199*89c4ff92SAndroid Build Coastguard Worker     return isSupported ? kTfLiteOk : kTfLiteError;
200*89c4ff92SAndroid Build Coastguard Worker }
201*89c4ff92SAndroid Build Coastguard Worker 
AddFloorDivLayer(DelegateData & delegateData,const armnn::TensorInfo & outputTensorInfo)202*89c4ff92SAndroid Build Coastguard Worker std::pair<armnn::IConnectableLayer*, armnn::IConnectableLayer*> AddFloorDivLayer(
203*89c4ff92SAndroid Build Coastguard Worker     DelegateData& delegateData,
204*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo& outputTensorInfo)
205*89c4ff92SAndroid Build Coastguard Worker {
206*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* divisionLayer = delegateData.m_Network->AddElementwiseBinaryLayer(
207*89c4ff92SAndroid Build Coastguard Worker             armnn::BinaryOperation::Div);
208*89c4ff92SAndroid Build Coastguard Worker     // if the output of the div is Signed32 the Floor layer is not required
209*89c4ff92SAndroid Build Coastguard Worker     if (armnn::DataType::Signed32 == outputTensorInfo.GetDataType())
210*89c4ff92SAndroid Build Coastguard Worker     {
211*89c4ff92SAndroid Build Coastguard Worker         return std::make_pair(divisionLayer, divisionLayer);
212*89c4ff92SAndroid Build Coastguard Worker     }
213*89c4ff92SAndroid Build Coastguard Worker     armnn::IOutputSlot& outputSlot = divisionLayer->GetOutputSlot(0);
214*89c4ff92SAndroid Build Coastguard Worker     outputSlot.SetTensorInfo(outputTensorInfo);
215*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* floorLayer = delegateData.m_Network->AddFloorLayer();
216*89c4ff92SAndroid Build Coastguard Worker     outputSlot.Connect(floorLayer->GetInputSlot(0));
217*89c4ff92SAndroid Build Coastguard Worker     return std::make_pair(divisionLayer, floorLayer);
218*89c4ff92SAndroid Build Coastguard Worker }
219*89c4ff92SAndroid Build Coastguard Worker 
VisitElementwiseBinaryOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t elementwiseBinaryOperatorCode)220*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus VisitElementwiseBinaryOperator(DelegateData& delegateData,
221*89c4ff92SAndroid Build Coastguard Worker                                             TfLiteContext* tfLiteContext,
222*89c4ff92SAndroid Build Coastguard Worker                                             TfLiteNode* tfLiteNode,
223*89c4ff92SAndroid Build Coastguard Worker                                             int nodeIndex,
224*89c4ff92SAndroid Build Coastguard Worker                                             int32_t elementwiseBinaryOperatorCode)
225*89c4ff92SAndroid Build Coastguard Worker {
226*89c4ff92SAndroid Build Coastguard Worker     TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
227*89c4ff92SAndroid Build Coastguard Worker     TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
228*89c4ff92SAndroid Build Coastguard Worker 
229*89c4ff92SAndroid Build Coastguard Worker     const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
230*89c4ff92SAndroid Build Coastguard Worker     const TfLiteTensor& tfLiteInputTensor0 = tfLiteTensors[tfLiteNode->inputs->data[0]];
231*89c4ff92SAndroid Build Coastguard Worker     if (IsDynamicTensor(tfLiteInputTensor0))
232*89c4ff92SAndroid Build Coastguard Worker     {
233*89c4ff92SAndroid Build Coastguard Worker         TF_LITE_MAYBE_KERNEL_LOG(
234*89c4ff92SAndroid Build Coastguard Worker             tfLiteContext,
235*89c4ff92SAndroid Build Coastguard Worker             "TfLiteArmnnDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ",
236*89c4ff92SAndroid Build Coastguard Worker             elementwiseBinaryOperatorCode, nodeIndex);
237*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
238*89c4ff92SAndroid Build Coastguard Worker     }
239*89c4ff92SAndroid Build Coastguard Worker 
240*89c4ff92SAndroid Build Coastguard Worker     const TfLiteTensor& tfLiteInputTensor1 = tfLiteTensors[tfLiteNode->inputs->data[1]];
241*89c4ff92SAndroid Build Coastguard Worker     if (IsDynamicTensor(tfLiteInputTensor1))
242*89c4ff92SAndroid Build Coastguard Worker     {
243*89c4ff92SAndroid Build Coastguard Worker         TF_LITE_MAYBE_KERNEL_LOG(
244*89c4ff92SAndroid Build Coastguard Worker             tfLiteContext,
245*89c4ff92SAndroid Build Coastguard Worker             "TfLiteArmnnDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ",
246*89c4ff92SAndroid Build Coastguard Worker             elementwiseBinaryOperatorCode, nodeIndex);
247*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
248*89c4ff92SAndroid Build Coastguard Worker     }
249*89c4ff92SAndroid Build Coastguard Worker 
250*89c4ff92SAndroid Build Coastguard Worker     const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
251*89c4ff92SAndroid Build Coastguard Worker     if (IsDynamicTensor(tfLiteOutputTensor))
252*89c4ff92SAndroid Build Coastguard Worker     {
253*89c4ff92SAndroid Build Coastguard Worker         TF_LITE_MAYBE_KERNEL_LOG(
254*89c4ff92SAndroid Build Coastguard Worker             tfLiteContext,
255*89c4ff92SAndroid Build Coastguard Worker             "TfLiteArmnnDelegate: Dynamic output tensors are not supported in operator #%d node #%d: ",
256*89c4ff92SAndroid Build Coastguard Worker             elementwiseBinaryOperatorCode, nodeIndex);
257*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
258*89c4ff92SAndroid Build Coastguard Worker     }
259*89c4ff92SAndroid Build Coastguard Worker 
260*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo0 = GetTensorInfoForTfLiteTensor(tfLiteInputTensor0);
261*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo1 = GetTensorInfoForTfLiteTensor(tfLiteInputTensor1);
262*89c4ff92SAndroid Build Coastguard Worker 
263*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
264*89c4ff92SAndroid Build Coastguard Worker 
265*89c4ff92SAndroid Build Coastguard Worker     // Check if we need to expand the dims of the input tensor infos.
266*89c4ff92SAndroid Build Coastguard Worker     // This is required for a few of the backends.
267*89c4ff92SAndroid Build Coastguard Worker     if(inputTensorInfo0.GetNumDimensions() != inputTensorInfo1.GetNumDimensions())
268*89c4ff92SAndroid Build Coastguard Worker     {
269*89c4ff92SAndroid Build Coastguard Worker         ExpandTensorRankToEqual(inputTensorInfo0, inputTensorInfo1);
270*89c4ff92SAndroid Build Coastguard Worker     }
271*89c4ff92SAndroid Build Coastguard Worker 
272*89c4ff92SAndroid Build Coastguard Worker     auto* tfLiteNodeParameters = reinterpret_cast<TfLiteAddParams*>(tfLiteNode->builtin_data);
273*89c4ff92SAndroid Build Coastguard Worker     TfLiteFusedActivation activationType = kTfLiteActNone;
274*89c4ff92SAndroid Build Coastguard Worker     if (tfLiteNodeParameters)
275*89c4ff92SAndroid Build Coastguard Worker     {
276*89c4ff92SAndroid Build Coastguard Worker         activationType = tfLiteNodeParameters->activation;
277*89c4ff92SAndroid Build Coastguard Worker         TfLiteStatus activationStatus = ValidateFusedActivationOperator(delegateData, tfLiteContext, outputTensorInfo,
278*89c4ff92SAndroid Build Coastguard Worker                                                                         outputTensorInfo, activationType);
279*89c4ff92SAndroid Build Coastguard Worker         if(activationStatus != kTfLiteOk)
280*89c4ff92SAndroid Build Coastguard Worker         {
281*89c4ff92SAndroid Build Coastguard Worker             return kTfLiteError;
282*89c4ff92SAndroid Build Coastguard Worker         }
283*89c4ff92SAndroid Build Coastguard Worker     }
284*89c4ff92SAndroid Build Coastguard Worker 
285*89c4ff92SAndroid Build Coastguard Worker     if (!delegateData.m_Network)
286*89c4ff92SAndroid Build Coastguard Worker     {
287*89c4ff92SAndroid Build Coastguard Worker         switch(elementwiseBinaryOperatorCode)
288*89c4ff92SAndroid Build Coastguard Worker         {
289*89c4ff92SAndroid Build Coastguard Worker             case kTfLiteBuiltinAdd:
290*89c4ff92SAndroid Build Coastguard Worker                 return ValidateAddOperator(delegateData,
291*89c4ff92SAndroid Build Coastguard Worker                                            tfLiteContext,
292*89c4ff92SAndroid Build Coastguard Worker                                            inputTensorInfo0,
293*89c4ff92SAndroid Build Coastguard Worker                                            inputTensorInfo1,
294*89c4ff92SAndroid Build Coastguard Worker                                            outputTensorInfo);
295*89c4ff92SAndroid Build Coastguard Worker             case kTfLiteBuiltinDiv:
296*89c4ff92SAndroid Build Coastguard Worker                 return ValidateDivOperator(delegateData,
297*89c4ff92SAndroid Build Coastguard Worker                                            tfLiteContext,
298*89c4ff92SAndroid Build Coastguard Worker                                            inputTensorInfo0,
299*89c4ff92SAndroid Build Coastguard Worker                                            inputTensorInfo1,
300*89c4ff92SAndroid Build Coastguard Worker                                            outputTensorInfo);
301*89c4ff92SAndroid Build Coastguard Worker             case kTfLiteBuiltinFloorDiv:
302*89c4ff92SAndroid Build Coastguard Worker                 return ValidateFloorDivOperator(delegateData,
303*89c4ff92SAndroid Build Coastguard Worker                                                 tfLiteContext,
304*89c4ff92SAndroid Build Coastguard Worker                                                 inputTensorInfo0,
305*89c4ff92SAndroid Build Coastguard Worker                                                 inputTensorInfo1,
306*89c4ff92SAndroid Build Coastguard Worker                                                 outputTensorInfo);
307*89c4ff92SAndroid Build Coastguard Worker             case kTfLiteBuiltinMaximum:
308*89c4ff92SAndroid Build Coastguard Worker                 return ValidateMaximumOperator(delegateData,
309*89c4ff92SAndroid Build Coastguard Worker                                                tfLiteContext,
310*89c4ff92SAndroid Build Coastguard Worker                                                inputTensorInfo0,
311*89c4ff92SAndroid Build Coastguard Worker                                                inputTensorInfo1,
312*89c4ff92SAndroid Build Coastguard Worker                                                outputTensorInfo);
313*89c4ff92SAndroid Build Coastguard Worker             case kTfLiteBuiltinMinimum:
314*89c4ff92SAndroid Build Coastguard Worker                 return ValidateMinimumOperator(delegateData,
315*89c4ff92SAndroid Build Coastguard Worker                                                tfLiteContext,
316*89c4ff92SAndroid Build Coastguard Worker                                                inputTensorInfo0,
317*89c4ff92SAndroid Build Coastguard Worker                                                inputTensorInfo1,
318*89c4ff92SAndroid Build Coastguard Worker                                                outputTensorInfo);
319*89c4ff92SAndroid Build Coastguard Worker             case kTfLiteBuiltinMul:
320*89c4ff92SAndroid Build Coastguard Worker                 return ValidateMulOperator(delegateData,
321*89c4ff92SAndroid Build Coastguard Worker                                            tfLiteContext,
322*89c4ff92SAndroid Build Coastguard Worker                                            inputTensorInfo0,
323*89c4ff92SAndroid Build Coastguard Worker                                            inputTensorInfo1,
324*89c4ff92SAndroid Build Coastguard Worker                                            outputTensorInfo);
325*89c4ff92SAndroid Build Coastguard Worker             case kTfLiteBuiltinSub:
326*89c4ff92SAndroid Build Coastguard Worker                 return ValidateSubOperator(delegateData,
327*89c4ff92SAndroid Build Coastguard Worker                                            tfLiteContext,
328*89c4ff92SAndroid Build Coastguard Worker                                            inputTensorInfo0,
329*89c4ff92SAndroid Build Coastguard Worker                                            inputTensorInfo1,
330*89c4ff92SAndroid Build Coastguard Worker                                            outputTensorInfo);
331*89c4ff92SAndroid Build Coastguard Worker             default:
332*89c4ff92SAndroid Build Coastguard Worker                 return kTfLiteError;
333*89c4ff92SAndroid Build Coastguard Worker         }
334*89c4ff92SAndroid Build Coastguard Worker     }
335*89c4ff92SAndroid Build Coastguard Worker 
336*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* elementwiseBinaryLayer = nullptr;
337*89c4ff92SAndroid Build Coastguard Worker     MultiLayerFacade multiLayer;
338*89c4ff92SAndroid Build Coastguard Worker     switch(elementwiseBinaryOperatorCode)
339*89c4ff92SAndroid Build Coastguard Worker     {
340*89c4ff92SAndroid Build Coastguard Worker         case kTfLiteBuiltinAdd:
341*89c4ff92SAndroid Build Coastguard Worker             elementwiseBinaryLayer = delegateData.m_Network->AddElementwiseBinaryLayer(
342*89c4ff92SAndroid Build Coastguard Worker                     armnn::BinaryOperation::Add);
343*89c4ff92SAndroid Build Coastguard Worker             break;
344*89c4ff92SAndroid Build Coastguard Worker         case kTfLiteBuiltinDiv:
345*89c4ff92SAndroid Build Coastguard Worker             elementwiseBinaryLayer = delegateData.m_Network->AddElementwiseBinaryLayer(
346*89c4ff92SAndroid Build Coastguard Worker                     armnn::BinaryOperation::Div);
347*89c4ff92SAndroid Build Coastguard Worker             break;
348*89c4ff92SAndroid Build Coastguard Worker         case kTfLiteBuiltinFloorDiv:
349*89c4ff92SAndroid Build Coastguard Worker             {
350*89c4ff92SAndroid Build Coastguard Worker                 auto layers = AddFloorDivLayer(delegateData, outputTensorInfo);
351*89c4ff92SAndroid Build Coastguard Worker                 multiLayer.AssignValues(layers.first, layers.second);
352*89c4ff92SAndroid Build Coastguard Worker                 elementwiseBinaryLayer = &multiLayer;
353*89c4ff92SAndroid Build Coastguard Worker             }
354*89c4ff92SAndroid Build Coastguard Worker             break;
355*89c4ff92SAndroid Build Coastguard Worker         case kTfLiteBuiltinMaximum:
356*89c4ff92SAndroid Build Coastguard Worker             elementwiseBinaryLayer = delegateData.m_Network->AddElementwiseBinaryLayer(
357*89c4ff92SAndroid Build Coastguard Worker                     armnn::BinaryOperation::Maximum);
358*89c4ff92SAndroid Build Coastguard Worker             break;
359*89c4ff92SAndroid Build Coastguard Worker         case kTfLiteBuiltinMinimum:
360*89c4ff92SAndroid Build Coastguard Worker             elementwiseBinaryLayer = delegateData.m_Network->AddElementwiseBinaryLayer(
361*89c4ff92SAndroid Build Coastguard Worker                     armnn::BinaryOperation::Minimum);
362*89c4ff92SAndroid Build Coastguard Worker             break;
363*89c4ff92SAndroid Build Coastguard Worker         case kTfLiteBuiltinMul:
364*89c4ff92SAndroid Build Coastguard Worker             elementwiseBinaryLayer = delegateData.m_Network->AddElementwiseBinaryLayer(
365*89c4ff92SAndroid Build Coastguard Worker                     armnn::BinaryOperation::Mul);
366*89c4ff92SAndroid Build Coastguard Worker             break;
367*89c4ff92SAndroid Build Coastguard Worker         case kTfLiteBuiltinSub:
368*89c4ff92SAndroid Build Coastguard Worker             elementwiseBinaryLayer = delegateData.m_Network->AddElementwiseBinaryLayer(
369*89c4ff92SAndroid Build Coastguard Worker                     armnn::BinaryOperation::Sub);
370*89c4ff92SAndroid Build Coastguard Worker             break;
371*89c4ff92SAndroid Build Coastguard Worker         default:
372*89c4ff92SAndroid Build Coastguard Worker             return kTfLiteError;
373*89c4ff92SAndroid Build Coastguard Worker     }
374*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(elementwiseBinaryLayer != nullptr);
375*89c4ff92SAndroid Build Coastguard Worker     armnn::IOutputSlot& outputSlot = elementwiseBinaryLayer->GetOutputSlot(0);
376*89c4ff92SAndroid Build Coastguard Worker     outputSlot.SetTensorInfo(outputTensorInfo);
377*89c4ff92SAndroid Build Coastguard Worker 
378*89c4ff92SAndroid Build Coastguard Worker     auto inputsTensorsProcess = ProcessInputs(elementwiseBinaryLayer,
379*89c4ff92SAndroid Build Coastguard Worker                                               delegateData,
380*89c4ff92SAndroid Build Coastguard Worker                                               tfLiteContext,
381*89c4ff92SAndroid Build Coastguard Worker                                               tfLiteNode);
382*89c4ff92SAndroid Build Coastguard Worker     if (inputsTensorsProcess == kTfLiteError)
383*89c4ff92SAndroid Build Coastguard Worker     {
384*89c4ff92SAndroid Build Coastguard Worker         return inputsTensorsProcess;
385*89c4ff92SAndroid Build Coastguard Worker     }
386*89c4ff92SAndroid Build Coastguard Worker 
387*89c4ff92SAndroid Build Coastguard Worker     if(Connect(elementwiseBinaryLayer, tfLiteNode, delegateData) != kTfLiteOk)
388*89c4ff92SAndroid Build Coastguard Worker     {
389*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
390*89c4ff92SAndroid Build Coastguard Worker     }
391*89c4ff92SAndroid Build Coastguard Worker 
392*89c4ff92SAndroid Build Coastguard Worker     if (!tfLiteNodeParameters)
393*89c4ff92SAndroid Build Coastguard Worker     {
394*89c4ff92SAndroid Build Coastguard Worker         // No Activation
395*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteOk;
396*89c4ff92SAndroid Build Coastguard Worker     }
397*89c4ff92SAndroid Build Coastguard Worker     // Check and Create Activation
398*89c4ff92SAndroid Build Coastguard Worker     return FusedActivation(tfLiteContext, tfLiteNode, activationType, elementwiseBinaryLayer, 0, delegateData);
399*89c4ff92SAndroid Build Coastguard Worker }
400*89c4ff92SAndroid Build Coastguard Worker 
401*89c4ff92SAndroid Build Coastguard Worker } // namespace armnnDelegate
402