1 //
2 // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #pragma once
7
8 #include "SharedFunctions.hpp"
9
10 #include <tensorflow/lite/builtin_ops.h>
11 #include <tensorflow/lite/c/builtin_op_data.h>
12 #include <tensorflow/lite/c/common.h>
13 #include <tensorflow/lite/minimal_logging.h>
14
15 namespace armnnDelegate
16 {
17
VisitFloorOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t operatorCode)18 TfLiteStatus VisitFloorOperator(DelegateData& delegateData,
19 TfLiteContext* tfLiteContext,
20 TfLiteNode* tfLiteNode,
21 int nodeIndex,
22 int32_t operatorCode)
23 {
24 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
25 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
26
27 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
28 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
29 if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
30 {
31 return kTfLiteError;
32 }
33
34 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
35 if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
36 {
37 return kTfLiteError;
38 }
39
40 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
41 // NOTE: looks like the outputTensorInfo is the only thing that is required for the case
42 // where we are adding the floor layer so maybe move the other stuff inside the
43 // if !delegateData block for efficiency.
44 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
45
46 // If the m_Network is a nullptr, this signals that a prerequisite TfLite callback is required to clarify the
47 // support for the operator
48 // If supported, VisitFloorOperator will be called again to add the layer to the network as seen further below
49 if (!delegateData.m_Network)
50 {
51 return ValidateFloorOperator(delegateData, tfLiteContext, inputTensorInfo, outputTensorInfo);
52 }
53
54 // Add a Floor layer
55 armnn::IConnectableLayer* layer = delegateData.m_Network->AddFloorLayer();
56 ARMNN_ASSERT(layer != nullptr);
57
58 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
59 outputSlot.SetTensorInfo(outputTensorInfo);
60
61 // try to connect the Constant Inputs if there are any
62 if(ProcessInputs(layer,delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk )
63 {
64 return kTfLiteError;
65 }
66
67 // Connect
68 return Connect(layer, tfLiteNode, delegateData);
69 }
70
71 } // namespace armnnDelegate
72