xref: /aosp_15_r20/external/armnn/delegate/classic/src/Fill.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <armnn/utility/IgnoreUnused.hpp>
9 
10 #include <tensorflow/lite/builtin_ops.h>
11 #include <tensorflow/lite/c/builtin_op_data.h>
12 #include <tensorflow/lite/c/common.h>
13 #include <tensorflow/lite/minimal_logging.h>
14 
15 namespace armnnDelegate
16 {
17 
VisitFillOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t tfLiteFillOperatorCode)18 TfLiteStatus VisitFillOperator(DelegateData& delegateData,
19                                TfLiteContext* tfLiteContext,
20                                TfLiteNode* tfLiteNode,
21                                int nodeIndex,
22                                int32_t tfLiteFillOperatorCode)
23 {
24     TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
25 
26     switch(tfLiteFillOperatorCode)
27     {
28         case kTfLiteBuiltinFill:
29             TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
30             break;
31         default:
32             return kTfLiteError;
33     }
34 
35     const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
36     const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
37     if (!IsValid(tfLiteContext, tfLiteInputTensor, tfLiteFillOperatorCode, nodeIndex))
38     {
39         return kTfLiteError;
40     }
41 
42     const TfLiteTensor& tfLiteFillTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
43     if (!IsValid(tfLiteContext, tfLiteFillTensor, tfLiteFillOperatorCode, nodeIndex))
44     {
45         return kTfLiteError;
46     }
47 
48     const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
49     if (!IsValid(tfLiteContext, tfLiteOutputTensor, tfLiteFillOperatorCode, nodeIndex))
50     {
51         return kTfLiteError;
52     }
53 
54     armnn::TensorInfo inputTensorInfo  = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
55     const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
56 
57     armnn::FillDescriptor descriptor;
58     switch (tfLiteFillTensor.type)
59     {
60         case kTfLiteFloat32:
61             descriptor.m_Value = tflite::GetTensorData<float>(&tfLiteFillTensor)[0];
62             break;
63         case kTfLiteInt32:
64             descriptor.m_Value = tflite::GetTensorData<int32_t>(&tfLiteFillTensor)[0];
65             break;
66         default:
67             TF_LITE_MAYBE_KERNEL_LOG(
68                 tfLiteContext,
69                 "TfLiteArmnnDelegate: FILL value data type is not supported in operator #%d node #%d: ",
70                 tfLiteFillOperatorCode, nodeIndex);
71             return kTfLiteError;
72     }
73 
74     bool isSupported = false;
75     armnn::BackendId setBackend;
76     auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported)
77     {
78         FORWARD_LAYER_SUPPORT_FUNC("FILL",
79                                    tfLiteContext,
80                                    IsFillSupported,
81                                    delegateData.m_Backends,
82                                    isSupported,
83                                    setBackend,
84                                    inputTensorInfo,
85                                    outInfo,
86                                    descriptor);
87     };
88 
89     if (!delegateData.m_Network)
90     {
91         validateFunc(outputTensorInfo, isSupported);
92         return isSupported ? kTfLiteOk : kTfLiteError;
93     }
94 
95     armnn::IConnectableLayer* layer = delegateData.m_Network->AddFillLayer(descriptor);
96     layer->SetBackendId(setBackend);
97     ARMNN_ASSERT(layer != nullptr);
98 
99     armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
100     outputSlot.SetTensorInfo(outputTensorInfo);
101 
102     auto inputsTensorsProcess = ProcessInputs(layer,
103                                               delegateData,
104                                               tfLiteContext,
105                                               tfLiteNode);
106     if (inputsTensorsProcess == kTfLiteError)
107     {
108         return inputsTensorsProcess;
109     }
110 
111     return Connect(layer, tfLiteNode, delegateData);
112 }
113 
114 } // namespace armnnDelegate
115