1 //
2 // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #pragma once
7
8 #include <armnn/utility/IgnoreUnused.hpp>
9
10 #include <tensorflow/lite/builtin_ops.h>
11 #include <tensorflow/lite/c/builtin_op_data.h>
12 #include <tensorflow/lite/c/common.h>
13 #include <tensorflow/lite/minimal_logging.h>
14
15 namespace armnnDelegate
16 {
17
VisitFillOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t tfLiteFillOperatorCode)18 TfLiteStatus VisitFillOperator(DelegateData& delegateData,
19 TfLiteContext* tfLiteContext,
20 TfLiteNode* tfLiteNode,
21 int nodeIndex,
22 int32_t tfLiteFillOperatorCode)
23 {
24 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
25
26 switch(tfLiteFillOperatorCode)
27 {
28 case kTfLiteBuiltinFill:
29 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
30 break;
31 default:
32 return kTfLiteError;
33 }
34
35 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
36 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
37 if (!IsValid(tfLiteContext, tfLiteInputTensor, tfLiteFillOperatorCode, nodeIndex))
38 {
39 return kTfLiteError;
40 }
41
42 const TfLiteTensor& tfLiteFillTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
43 if (!IsValid(tfLiteContext, tfLiteFillTensor, tfLiteFillOperatorCode, nodeIndex))
44 {
45 return kTfLiteError;
46 }
47
48 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
49 if (!IsValid(tfLiteContext, tfLiteOutputTensor, tfLiteFillOperatorCode, nodeIndex))
50 {
51 return kTfLiteError;
52 }
53
54 armnn::TensorInfo inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
55 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
56
57 armnn::FillDescriptor descriptor;
58 switch (tfLiteFillTensor.type)
59 {
60 case kTfLiteFloat32:
61 descriptor.m_Value = tflite::GetTensorData<float>(&tfLiteFillTensor)[0];
62 break;
63 case kTfLiteInt32:
64 descriptor.m_Value = tflite::GetTensorData<int32_t>(&tfLiteFillTensor)[0];
65 break;
66 default:
67 TF_LITE_MAYBE_KERNEL_LOG(
68 tfLiteContext,
69 "TfLiteArmnnDelegate: FILL value data type is not supported in operator #%d node #%d: ",
70 tfLiteFillOperatorCode, nodeIndex);
71 return kTfLiteError;
72 }
73
74 bool isSupported = false;
75 armnn::BackendId setBackend;
76 auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported)
77 {
78 FORWARD_LAYER_SUPPORT_FUNC("FILL",
79 tfLiteContext,
80 IsFillSupported,
81 delegateData.m_Backends,
82 isSupported,
83 setBackend,
84 inputTensorInfo,
85 outInfo,
86 descriptor);
87 };
88
89 if (!delegateData.m_Network)
90 {
91 validateFunc(outputTensorInfo, isSupported);
92 return isSupported ? kTfLiteOk : kTfLiteError;
93 }
94
95 armnn::IConnectableLayer* layer = delegateData.m_Network->AddFillLayer(descriptor);
96 layer->SetBackendId(setBackend);
97 ARMNN_ASSERT(layer != nullptr);
98
99 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
100 outputSlot.SetTensorInfo(outputTensorInfo);
101
102 auto inputsTensorsProcess = ProcessInputs(layer,
103 delegateData,
104 tfLiteContext,
105 tfLiteNode);
106 if (inputsTensorsProcess == kTfLiteError)
107 {
108 return inputsTensorsProcess;
109 }
110
111 return Connect(layer, tfLiteNode, delegateData);
112 }
113
114 } // namespace armnnDelegate
115