1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker
8*89c4ff92SAndroid Build Coastguard Worker #include <ClassicDelegateUtils.hpp>
9*89c4ff92SAndroid Build Coastguard Worker #include "MultiLayerFacade.hpp"
10*89c4ff92SAndroid Build Coastguard Worker #include "SharedFunctions.hpp"
11*89c4ff92SAndroid Build Coastguard Worker
12*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/builtin_ops.h>
13*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/c/builtin_op_data.h>
14*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/c/common.h>
15*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/minimal_logging.h>
16*89c4ff92SAndroid Build Coastguard Worker #include "tensorflow/lite/delegates/utils.h"
17*89c4ff92SAndroid Build Coastguard Worker
18*89c4ff92SAndroid Build Coastguard Worker namespace armnnDelegate
19*89c4ff92SAndroid Build Coastguard Worker {
20*89c4ff92SAndroid Build Coastguard Worker
ValidateAddOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,const armnn::TensorInfo & inputInfo1,const armnn::TensorInfo & inputInfo2,const armnn::TensorInfo & outputInfo)21*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus ValidateAddOperator(DelegateData& delegateData,
22*89c4ff92SAndroid Build Coastguard Worker TfLiteContext* tfLiteContext,
23*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& inputInfo1,
24*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& inputInfo2,
25*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& outputInfo)
26*89c4ff92SAndroid Build Coastguard Worker {
27*89c4ff92SAndroid Build Coastguard Worker bool isSupported = false;
28*89c4ff92SAndroid Build Coastguard Worker auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
29*89c4ff92SAndroid Build Coastguard Worker {
30*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::TensorInfo> infos { inputInfo1, inputInfo2, outputInfo };
31*89c4ff92SAndroid Build Coastguard Worker FORWARD_LAYER_SUPPORT_FUNC("ADD",
32*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
33*89c4ff92SAndroid Build Coastguard Worker IsElementwiseBinarySupported,
34*89c4ff92SAndroid Build Coastguard Worker delegateData.m_Backends,
35*89c4ff92SAndroid Build Coastguard Worker isSupported,
36*89c4ff92SAndroid Build Coastguard Worker armnn::BackendId(),
37*89c4ff92SAndroid Build Coastguard Worker inputInfo1,
38*89c4ff92SAndroid Build Coastguard Worker inputInfo2,
39*89c4ff92SAndroid Build Coastguard Worker outputInfo,
40*89c4ff92SAndroid Build Coastguard Worker armnn::BinaryOperation::Add);
41*89c4ff92SAndroid Build Coastguard Worker };
42*89c4ff92SAndroid Build Coastguard Worker
43*89c4ff92SAndroid Build Coastguard Worker validateFunc(outputInfo, isSupported);
44*89c4ff92SAndroid Build Coastguard Worker return isSupported ? kTfLiteOk : kTfLiteError;
45*89c4ff92SAndroid Build Coastguard Worker }
46*89c4ff92SAndroid Build Coastguard Worker
47*89c4ff92SAndroid Build Coastguard Worker
ValidateDivOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,const armnn::TensorInfo & inputInfo1,const armnn::TensorInfo & inputInfo2,const armnn::TensorInfo & outputInfo)48*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus ValidateDivOperator(DelegateData& delegateData,
49*89c4ff92SAndroid Build Coastguard Worker TfLiteContext* tfLiteContext,
50*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& inputInfo1,
51*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& inputInfo2,
52*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& outputInfo)
53*89c4ff92SAndroid Build Coastguard Worker {
54*89c4ff92SAndroid Build Coastguard Worker bool isSupported = false;
55*89c4ff92SAndroid Build Coastguard Worker auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
56*89c4ff92SAndroid Build Coastguard Worker {
57*89c4ff92SAndroid Build Coastguard Worker FORWARD_LAYER_SUPPORT_FUNC("DIV",
58*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
59*89c4ff92SAndroid Build Coastguard Worker IsElementwiseBinarySupported,
60*89c4ff92SAndroid Build Coastguard Worker delegateData.m_Backends,
61*89c4ff92SAndroid Build Coastguard Worker isSupported,
62*89c4ff92SAndroid Build Coastguard Worker armnn::BackendId(),
63*89c4ff92SAndroid Build Coastguard Worker inputInfo1,
64*89c4ff92SAndroid Build Coastguard Worker inputInfo2,
65*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo,
66*89c4ff92SAndroid Build Coastguard Worker armnn::BinaryOperation::Div);
67*89c4ff92SAndroid Build Coastguard Worker };
68*89c4ff92SAndroid Build Coastguard Worker
69*89c4ff92SAndroid Build Coastguard Worker validateFunc(outputInfo, isSupported);
70*89c4ff92SAndroid Build Coastguard Worker return isSupported ? kTfLiteOk : kTfLiteError;
71*89c4ff92SAndroid Build Coastguard Worker }
72*89c4ff92SAndroid Build Coastguard Worker
ValidateFloorDivOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,const armnn::TensorInfo & inputInfo1,const armnn::TensorInfo & inputInfo2,const armnn::TensorInfo & outputInfo)73*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus ValidateFloorDivOperator(DelegateData& delegateData,
74*89c4ff92SAndroid Build Coastguard Worker TfLiteContext* tfLiteContext,
75*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& inputInfo1,
76*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& inputInfo2,
77*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& outputInfo)
78*89c4ff92SAndroid Build Coastguard Worker {
79*89c4ff92SAndroid Build Coastguard Worker // need first to validate that the div operator is supported
80*89c4ff92SAndroid Build Coastguard Worker // then that the floor operator is supported
81*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus status = ValidateDivOperator(delegateData, tfLiteContext, inputInfo1, inputInfo2, outputInfo);
82*89c4ff92SAndroid Build Coastguard Worker if (status != kTfLiteOk)
83*89c4ff92SAndroid Build Coastguard Worker {
84*89c4ff92SAndroid Build Coastguard Worker return status;
85*89c4ff92SAndroid Build Coastguard Worker }
86*89c4ff92SAndroid Build Coastguard Worker // if the inputs and output of the div are all Signed32 we don't need to add the floor operator afterward.
87*89c4ff92SAndroid Build Coastguard Worker if (AreAllSigned32(inputInfo1, inputInfo2, outputInfo))
88*89c4ff92SAndroid Build Coastguard Worker {
89*89c4ff92SAndroid Build Coastguard Worker return status;
90*89c4ff92SAndroid Build Coastguard Worker }
91*89c4ff92SAndroid Build Coastguard Worker // in case broadcasting is being done from one of the inputs to the div
92*89c4ff92SAndroid Build Coastguard Worker // choose the full sized input tensor to pass to the floor validation routine
93*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo floorInputInfo = inputInfo1;
94*89c4ff92SAndroid Build Coastguard Worker if (inputInfo1.GetNumDimensions() < inputInfo2.GetNumDimensions())
95*89c4ff92SAndroid Build Coastguard Worker {
96*89c4ff92SAndroid Build Coastguard Worker floorInputInfo = inputInfo2;
97*89c4ff92SAndroid Build Coastguard Worker }
98*89c4ff92SAndroid Build Coastguard Worker status = ValidateFloorOperator(delegateData, tfLiteContext, floorInputInfo, outputInfo);
99*89c4ff92SAndroid Build Coastguard Worker return status;
100*89c4ff92SAndroid Build Coastguard Worker }
101*89c4ff92SAndroid Build Coastguard Worker
ValidateMaximumOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,const armnn::TensorInfo & inputInfo1,const armnn::TensorInfo & inputInfo2,const armnn::TensorInfo & outputInfo)102*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus ValidateMaximumOperator(DelegateData& delegateData,
103*89c4ff92SAndroid Build Coastguard Worker TfLiteContext* tfLiteContext,
104*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& inputInfo1,
105*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& inputInfo2,
106*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& outputInfo)
107*89c4ff92SAndroid Build Coastguard Worker {
108*89c4ff92SAndroid Build Coastguard Worker bool isSupported = false;
109*89c4ff92SAndroid Build Coastguard Worker auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
110*89c4ff92SAndroid Build Coastguard Worker {
111*89c4ff92SAndroid Build Coastguard Worker FORWARD_LAYER_SUPPORT_FUNC("MAXIMUM",
112*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
113*89c4ff92SAndroid Build Coastguard Worker IsElementwiseBinarySupported,
114*89c4ff92SAndroid Build Coastguard Worker delegateData.m_Backends,
115*89c4ff92SAndroid Build Coastguard Worker isSupported,
116*89c4ff92SAndroid Build Coastguard Worker armnn::BackendId(),
117*89c4ff92SAndroid Build Coastguard Worker inputInfo1,
118*89c4ff92SAndroid Build Coastguard Worker inputInfo2,
119*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo,
120*89c4ff92SAndroid Build Coastguard Worker armnn::BinaryOperation::Maximum);
121*89c4ff92SAndroid Build Coastguard Worker };
122*89c4ff92SAndroid Build Coastguard Worker
123*89c4ff92SAndroid Build Coastguard Worker validateFunc(outputInfo, isSupported);
124*89c4ff92SAndroid Build Coastguard Worker return isSupported ? kTfLiteOk : kTfLiteError;
125*89c4ff92SAndroid Build Coastguard Worker }
126*89c4ff92SAndroid Build Coastguard Worker
ValidateMinimumOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,const armnn::TensorInfo & inputInfo1,const armnn::TensorInfo & inputInfo2,const armnn::TensorInfo & outputInfo)127*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus ValidateMinimumOperator(DelegateData& delegateData,
128*89c4ff92SAndroid Build Coastguard Worker TfLiteContext* tfLiteContext,
129*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& inputInfo1,
130*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& inputInfo2,
131*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& outputInfo)
132*89c4ff92SAndroid Build Coastguard Worker {
133*89c4ff92SAndroid Build Coastguard Worker bool isSupported = false;
134*89c4ff92SAndroid Build Coastguard Worker auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
135*89c4ff92SAndroid Build Coastguard Worker {
136*89c4ff92SAndroid Build Coastguard Worker FORWARD_LAYER_SUPPORT_FUNC("MINIMUM",
137*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
138*89c4ff92SAndroid Build Coastguard Worker IsElementwiseBinarySupported,
139*89c4ff92SAndroid Build Coastguard Worker delegateData.m_Backends,
140*89c4ff92SAndroid Build Coastguard Worker isSupported,
141*89c4ff92SAndroid Build Coastguard Worker armnn::BackendId(),
142*89c4ff92SAndroid Build Coastguard Worker inputInfo1,
143*89c4ff92SAndroid Build Coastguard Worker inputInfo2,
144*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo,
145*89c4ff92SAndroid Build Coastguard Worker armnn::BinaryOperation::Minimum);
146*89c4ff92SAndroid Build Coastguard Worker };
147*89c4ff92SAndroid Build Coastguard Worker
148*89c4ff92SAndroid Build Coastguard Worker validateFunc(outputInfo, isSupported);
149*89c4ff92SAndroid Build Coastguard Worker return isSupported ? kTfLiteOk : kTfLiteError;
150*89c4ff92SAndroid Build Coastguard Worker }
151*89c4ff92SAndroid Build Coastguard Worker
ValidateMulOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,const armnn::TensorInfo & inputInfo1,const armnn::TensorInfo & inputInfo2,const armnn::TensorInfo & outputInfo)152*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus ValidateMulOperator(DelegateData& delegateData,
153*89c4ff92SAndroid Build Coastguard Worker TfLiteContext* tfLiteContext,
154*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& inputInfo1,
155*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& inputInfo2,
156*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& outputInfo)
157*89c4ff92SAndroid Build Coastguard Worker {
158*89c4ff92SAndroid Build Coastguard Worker bool isSupported = false;
159*89c4ff92SAndroid Build Coastguard Worker auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
160*89c4ff92SAndroid Build Coastguard Worker {
161*89c4ff92SAndroid Build Coastguard Worker FORWARD_LAYER_SUPPORT_FUNC("MUL",
162*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
163*89c4ff92SAndroid Build Coastguard Worker IsElementwiseBinarySupported,
164*89c4ff92SAndroid Build Coastguard Worker delegateData.m_Backends,
165*89c4ff92SAndroid Build Coastguard Worker isSupported,
166*89c4ff92SAndroid Build Coastguard Worker armnn::BackendId(),
167*89c4ff92SAndroid Build Coastguard Worker inputInfo1,
168*89c4ff92SAndroid Build Coastguard Worker inputInfo2,
169*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo,
170*89c4ff92SAndroid Build Coastguard Worker armnn::BinaryOperation::Mul);
171*89c4ff92SAndroid Build Coastguard Worker };
172*89c4ff92SAndroid Build Coastguard Worker
173*89c4ff92SAndroid Build Coastguard Worker validateFunc(outputInfo, isSupported);
174*89c4ff92SAndroid Build Coastguard Worker return isSupported ? kTfLiteOk : kTfLiteError;
175*89c4ff92SAndroid Build Coastguard Worker }
176*89c4ff92SAndroid Build Coastguard Worker
ValidateSubOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,const armnn::TensorInfo & inputInfo1,const armnn::TensorInfo & inputInfo2,const armnn::TensorInfo & outputInfo)177*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus ValidateSubOperator(DelegateData& delegateData,
178*89c4ff92SAndroid Build Coastguard Worker TfLiteContext* tfLiteContext,
179*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& inputInfo1,
180*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& inputInfo2,
181*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& outputInfo)
182*89c4ff92SAndroid Build Coastguard Worker {
183*89c4ff92SAndroid Build Coastguard Worker bool isSupported = false;
184*89c4ff92SAndroid Build Coastguard Worker auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
185*89c4ff92SAndroid Build Coastguard Worker {
186*89c4ff92SAndroid Build Coastguard Worker FORWARD_LAYER_SUPPORT_FUNC("SUB",
187*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
188*89c4ff92SAndroid Build Coastguard Worker IsElementwiseBinarySupported,
189*89c4ff92SAndroid Build Coastguard Worker delegateData.m_Backends,
190*89c4ff92SAndroid Build Coastguard Worker isSupported,
191*89c4ff92SAndroid Build Coastguard Worker armnn::BackendId(),
192*89c4ff92SAndroid Build Coastguard Worker inputInfo1,
193*89c4ff92SAndroid Build Coastguard Worker inputInfo2,
194*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo,
195*89c4ff92SAndroid Build Coastguard Worker armnn::BinaryOperation::Sub);
196*89c4ff92SAndroid Build Coastguard Worker };
197*89c4ff92SAndroid Build Coastguard Worker
198*89c4ff92SAndroid Build Coastguard Worker validateFunc(outputInfo, isSupported);
199*89c4ff92SAndroid Build Coastguard Worker return isSupported ? kTfLiteOk : kTfLiteError;
200*89c4ff92SAndroid Build Coastguard Worker }
201*89c4ff92SAndroid Build Coastguard Worker
AddFloorDivLayer(DelegateData & delegateData,const armnn::TensorInfo & outputTensorInfo)202*89c4ff92SAndroid Build Coastguard Worker std::pair<armnn::IConnectableLayer*, armnn::IConnectableLayer*> AddFloorDivLayer(
203*89c4ff92SAndroid Build Coastguard Worker DelegateData& delegateData,
204*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& outputTensorInfo)
205*89c4ff92SAndroid Build Coastguard Worker {
206*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* divisionLayer = delegateData.m_Network->AddElementwiseBinaryLayer(
207*89c4ff92SAndroid Build Coastguard Worker armnn::BinaryOperation::Div);
208*89c4ff92SAndroid Build Coastguard Worker // if the output of the div is Signed32 the Floor layer is not required
209*89c4ff92SAndroid Build Coastguard Worker if (armnn::DataType::Signed32 == outputTensorInfo.GetDataType())
210*89c4ff92SAndroid Build Coastguard Worker {
211*89c4ff92SAndroid Build Coastguard Worker return std::make_pair(divisionLayer, divisionLayer);
212*89c4ff92SAndroid Build Coastguard Worker }
213*89c4ff92SAndroid Build Coastguard Worker armnn::IOutputSlot& outputSlot = divisionLayer->GetOutputSlot(0);
214*89c4ff92SAndroid Build Coastguard Worker outputSlot.SetTensorInfo(outputTensorInfo);
215*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* floorLayer = delegateData.m_Network->AddFloorLayer();
216*89c4ff92SAndroid Build Coastguard Worker outputSlot.Connect(floorLayer->GetInputSlot(0));
217*89c4ff92SAndroid Build Coastguard Worker return std::make_pair(divisionLayer, floorLayer);
218*89c4ff92SAndroid Build Coastguard Worker }
219*89c4ff92SAndroid Build Coastguard Worker
VisitElementwiseBinaryOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t elementwiseBinaryOperatorCode)220*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus VisitElementwiseBinaryOperator(DelegateData& delegateData,
221*89c4ff92SAndroid Build Coastguard Worker TfLiteContext* tfLiteContext,
222*89c4ff92SAndroid Build Coastguard Worker TfLiteNode* tfLiteNode,
223*89c4ff92SAndroid Build Coastguard Worker int nodeIndex,
224*89c4ff92SAndroid Build Coastguard Worker int32_t elementwiseBinaryOperatorCode)
225*89c4ff92SAndroid Build Coastguard Worker {
226*89c4ff92SAndroid Build Coastguard Worker TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
227*89c4ff92SAndroid Build Coastguard Worker TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
228*89c4ff92SAndroid Build Coastguard Worker
229*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
230*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor& tfLiteInputTensor0 = tfLiteTensors[tfLiteNode->inputs->data[0]];
231*89c4ff92SAndroid Build Coastguard Worker if (IsDynamicTensor(tfLiteInputTensor0))
232*89c4ff92SAndroid Build Coastguard Worker {
233*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
234*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
235*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ",
236*89c4ff92SAndroid Build Coastguard Worker elementwiseBinaryOperatorCode, nodeIndex);
237*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
238*89c4ff92SAndroid Build Coastguard Worker }
239*89c4ff92SAndroid Build Coastguard Worker
240*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor& tfLiteInputTensor1 = tfLiteTensors[tfLiteNode->inputs->data[1]];
241*89c4ff92SAndroid Build Coastguard Worker if (IsDynamicTensor(tfLiteInputTensor1))
242*89c4ff92SAndroid Build Coastguard Worker {
243*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
244*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
245*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ",
246*89c4ff92SAndroid Build Coastguard Worker elementwiseBinaryOperatorCode, nodeIndex);
247*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
248*89c4ff92SAndroid Build Coastguard Worker }
249*89c4ff92SAndroid Build Coastguard Worker
250*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
251*89c4ff92SAndroid Build Coastguard Worker if (IsDynamicTensor(tfLiteOutputTensor))
252*89c4ff92SAndroid Build Coastguard Worker {
253*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
254*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
255*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnDelegate: Dynamic output tensors are not supported in operator #%d node #%d: ",
256*89c4ff92SAndroid Build Coastguard Worker elementwiseBinaryOperatorCode, nodeIndex);
257*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
258*89c4ff92SAndroid Build Coastguard Worker }
259*89c4ff92SAndroid Build Coastguard Worker
260*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo0 = GetTensorInfoForTfLiteTensor(tfLiteInputTensor0);
261*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo1 = GetTensorInfoForTfLiteTensor(tfLiteInputTensor1);
262*89c4ff92SAndroid Build Coastguard Worker
263*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
264*89c4ff92SAndroid Build Coastguard Worker
265*89c4ff92SAndroid Build Coastguard Worker // Check if we need to expand the dims of the input tensor infos.
266*89c4ff92SAndroid Build Coastguard Worker // This is required for a few of the backends.
267*89c4ff92SAndroid Build Coastguard Worker if(inputTensorInfo0.GetNumDimensions() != inputTensorInfo1.GetNumDimensions())
268*89c4ff92SAndroid Build Coastguard Worker {
269*89c4ff92SAndroid Build Coastguard Worker ExpandTensorRankToEqual(inputTensorInfo0, inputTensorInfo1);
270*89c4ff92SAndroid Build Coastguard Worker }
271*89c4ff92SAndroid Build Coastguard Worker
272*89c4ff92SAndroid Build Coastguard Worker auto* tfLiteNodeParameters = reinterpret_cast<TfLiteAddParams*>(tfLiteNode->builtin_data);
273*89c4ff92SAndroid Build Coastguard Worker TfLiteFusedActivation activationType = kTfLiteActNone;
274*89c4ff92SAndroid Build Coastguard Worker if (tfLiteNodeParameters)
275*89c4ff92SAndroid Build Coastguard Worker {
276*89c4ff92SAndroid Build Coastguard Worker activationType = tfLiteNodeParameters->activation;
277*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus activationStatus = ValidateFusedActivationOperator(delegateData, tfLiteContext, outputTensorInfo,
278*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo, activationType);
279*89c4ff92SAndroid Build Coastguard Worker if(activationStatus != kTfLiteOk)
280*89c4ff92SAndroid Build Coastguard Worker {
281*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
282*89c4ff92SAndroid Build Coastguard Worker }
283*89c4ff92SAndroid Build Coastguard Worker }
284*89c4ff92SAndroid Build Coastguard Worker
285*89c4ff92SAndroid Build Coastguard Worker if (!delegateData.m_Network)
286*89c4ff92SAndroid Build Coastguard Worker {
287*89c4ff92SAndroid Build Coastguard Worker switch(elementwiseBinaryOperatorCode)
288*89c4ff92SAndroid Build Coastguard Worker {
289*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinAdd:
290*89c4ff92SAndroid Build Coastguard Worker return ValidateAddOperator(delegateData,
291*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
292*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo0,
293*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo1,
294*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo);
295*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinDiv:
296*89c4ff92SAndroid Build Coastguard Worker return ValidateDivOperator(delegateData,
297*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
298*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo0,
299*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo1,
300*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo);
301*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinFloorDiv:
302*89c4ff92SAndroid Build Coastguard Worker return ValidateFloorDivOperator(delegateData,
303*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
304*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo0,
305*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo1,
306*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo);
307*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinMaximum:
308*89c4ff92SAndroid Build Coastguard Worker return ValidateMaximumOperator(delegateData,
309*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
310*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo0,
311*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo1,
312*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo);
313*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinMinimum:
314*89c4ff92SAndroid Build Coastguard Worker return ValidateMinimumOperator(delegateData,
315*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
316*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo0,
317*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo1,
318*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo);
319*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinMul:
320*89c4ff92SAndroid Build Coastguard Worker return ValidateMulOperator(delegateData,
321*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
322*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo0,
323*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo1,
324*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo);
325*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinSub:
326*89c4ff92SAndroid Build Coastguard Worker return ValidateSubOperator(delegateData,
327*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
328*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo0,
329*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo1,
330*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo);
331*89c4ff92SAndroid Build Coastguard Worker default:
332*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
333*89c4ff92SAndroid Build Coastguard Worker }
334*89c4ff92SAndroid Build Coastguard Worker }
335*89c4ff92SAndroid Build Coastguard Worker
336*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* elementwiseBinaryLayer = nullptr;
337*89c4ff92SAndroid Build Coastguard Worker MultiLayerFacade multiLayer;
338*89c4ff92SAndroid Build Coastguard Worker switch(elementwiseBinaryOperatorCode)
339*89c4ff92SAndroid Build Coastguard Worker {
340*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinAdd:
341*89c4ff92SAndroid Build Coastguard Worker elementwiseBinaryLayer = delegateData.m_Network->AddElementwiseBinaryLayer(
342*89c4ff92SAndroid Build Coastguard Worker armnn::BinaryOperation::Add);
343*89c4ff92SAndroid Build Coastguard Worker break;
344*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinDiv:
345*89c4ff92SAndroid Build Coastguard Worker elementwiseBinaryLayer = delegateData.m_Network->AddElementwiseBinaryLayer(
346*89c4ff92SAndroid Build Coastguard Worker armnn::BinaryOperation::Div);
347*89c4ff92SAndroid Build Coastguard Worker break;
348*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinFloorDiv:
349*89c4ff92SAndroid Build Coastguard Worker {
350*89c4ff92SAndroid Build Coastguard Worker auto layers = AddFloorDivLayer(delegateData, outputTensorInfo);
351*89c4ff92SAndroid Build Coastguard Worker multiLayer.AssignValues(layers.first, layers.second);
352*89c4ff92SAndroid Build Coastguard Worker elementwiseBinaryLayer = &multiLayer;
353*89c4ff92SAndroid Build Coastguard Worker }
354*89c4ff92SAndroid Build Coastguard Worker break;
355*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinMaximum:
356*89c4ff92SAndroid Build Coastguard Worker elementwiseBinaryLayer = delegateData.m_Network->AddElementwiseBinaryLayer(
357*89c4ff92SAndroid Build Coastguard Worker armnn::BinaryOperation::Maximum);
358*89c4ff92SAndroid Build Coastguard Worker break;
359*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinMinimum:
360*89c4ff92SAndroid Build Coastguard Worker elementwiseBinaryLayer = delegateData.m_Network->AddElementwiseBinaryLayer(
361*89c4ff92SAndroid Build Coastguard Worker armnn::BinaryOperation::Minimum);
362*89c4ff92SAndroid Build Coastguard Worker break;
363*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinMul:
364*89c4ff92SAndroid Build Coastguard Worker elementwiseBinaryLayer = delegateData.m_Network->AddElementwiseBinaryLayer(
365*89c4ff92SAndroid Build Coastguard Worker armnn::BinaryOperation::Mul);
366*89c4ff92SAndroid Build Coastguard Worker break;
367*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinSub:
368*89c4ff92SAndroid Build Coastguard Worker elementwiseBinaryLayer = delegateData.m_Network->AddElementwiseBinaryLayer(
369*89c4ff92SAndroid Build Coastguard Worker armnn::BinaryOperation::Sub);
370*89c4ff92SAndroid Build Coastguard Worker break;
371*89c4ff92SAndroid Build Coastguard Worker default:
372*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
373*89c4ff92SAndroid Build Coastguard Worker }
374*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(elementwiseBinaryLayer != nullptr);
375*89c4ff92SAndroid Build Coastguard Worker armnn::IOutputSlot& outputSlot = elementwiseBinaryLayer->GetOutputSlot(0);
376*89c4ff92SAndroid Build Coastguard Worker outputSlot.SetTensorInfo(outputTensorInfo);
377*89c4ff92SAndroid Build Coastguard Worker
378*89c4ff92SAndroid Build Coastguard Worker auto inputsTensorsProcess = ProcessInputs(elementwiseBinaryLayer,
379*89c4ff92SAndroid Build Coastguard Worker delegateData,
380*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
381*89c4ff92SAndroid Build Coastguard Worker tfLiteNode);
382*89c4ff92SAndroid Build Coastguard Worker if (inputsTensorsProcess == kTfLiteError)
383*89c4ff92SAndroid Build Coastguard Worker {
384*89c4ff92SAndroid Build Coastguard Worker return inputsTensorsProcess;
385*89c4ff92SAndroid Build Coastguard Worker }
386*89c4ff92SAndroid Build Coastguard Worker
387*89c4ff92SAndroid Build Coastguard Worker if(Connect(elementwiseBinaryLayer, tfLiteNode, delegateData) != kTfLiteOk)
388*89c4ff92SAndroid Build Coastguard Worker {
389*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
390*89c4ff92SAndroid Build Coastguard Worker }
391*89c4ff92SAndroid Build Coastguard Worker
392*89c4ff92SAndroid Build Coastguard Worker if (!tfLiteNodeParameters)
393*89c4ff92SAndroid Build Coastguard Worker {
394*89c4ff92SAndroid Build Coastguard Worker // No Activation
395*89c4ff92SAndroid Build Coastguard Worker return kTfLiteOk;
396*89c4ff92SAndroid Build Coastguard Worker }
397*89c4ff92SAndroid Build Coastguard Worker // Check and Create Activation
398*89c4ff92SAndroid Build Coastguard Worker return FusedActivation(tfLiteContext, tfLiteNode, activationType, elementwiseBinaryLayer, 0, delegateData);
399*89c4ff92SAndroid Build Coastguard Worker }
400*89c4ff92SAndroid Build Coastguard Worker
401*89c4ff92SAndroid Build Coastguard Worker } // namespace armnnDelegate
402