xref: /aosp_15_r20/external/armnn/delegate/opaque/src/Redefine.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include <armnn/utility/IgnoreUnused.hpp>
8 
9 #include "OpaqueDelegateUtils.hpp"
10 
11 #include <tensorflow/lite/builtin_ops.h>
12 #include <tensorflow/lite/c/builtin_op_data.h>
13 #include <tensorflow/lite/c/common.h>
14 #include <tensorflow/lite/minimal_logging.h>
15 #include <numeric>
16 
17 namespace armnnOpaqueDelegate
18 {
19 
VisitCastOperator(DelegateData & delegateData,TfLiteOpaqueContext * tfLiteContext,TfLiteOpaqueNode * tfLiteNode,int nodeIndex,int32_t operatorCode)20 TfLiteStatus VisitCastOperator(DelegateData& delegateData,
21                                TfLiteOpaqueContext* tfLiteContext,
22                                TfLiteOpaqueNode* tfLiteNode,
23                                int nodeIndex,
24                                int32_t operatorCode)
25 {
26     TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
27     TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
28     int numInputs = 0;
29     const int* inputTensors;
30     if (TfLiteOpaqueNodeInputs(tfLiteNode, &inputTensors, &numInputs) != kTfLiteOk)
31     {
32         return kTfLiteError;
33     }
34 
35     // This layer only has 1 input, so we can directly assign tensor[0] to a new opaque tensor
36     const TfLiteOpaqueTensor*
37           tfLiteInputTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputTensors[numInputs-1]);
38     if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
39     {
40         return kTfLiteError;
41     }
42 
43     int numOutputs = 0;
44     const int* outputTensors;
45     if (TfLiteOpaqueNodeOutputs(tfLiteNode, &outputTensors, &numOutputs) != kTfLiteOk)
46     {
47         return kTfLiteError;
48     }
49 
50     // This layer only has 1 output, so we can directly assign tensor[0] to a new opaque tensor
51     const TfLiteOpaqueTensor*
52           tfLiteOutputTensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, outputTensors[numOutputs-1]);
53     if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
54     {
55         return kTfLiteError;
56     }
57 
58     const armnn::TensorInfo& inputTensorInfo  = GetTensorInfoForTfLiteOpaqueTensor(tfLiteInputTensor);
59     const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteOutputTensor, true);
60 
61     bool             isSupported  = false;
62     armnn::BackendId setBackend;
63     auto             validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported) {
64         FORWARD_LAYER_OPAQUE_SUPPORT_FUNC("CAST",
65                                    tfLiteContext,
66                                    IsCastSupported,
67                                    delegateData.m_Backends,
68                                    isSupported,
69                                    setBackend,
70                                    inputTensorInfo,
71                                    outInfo);
72     };
73 
74     // If the m_Network is a nullptr, this signals that a prerequisite TfLite callback is required to clarify the
75     // support for the operator
76     // If supported, VisitCastOperator will be called again to add the layer to the network as seen further below
77     if (!delegateData.m_Network)
78     {
79         validateFunc(outputTensorInfo, isSupported);
80         return isSupported ? kTfLiteOk : kTfLiteError;
81     }
82 
83     // Add a Cast layer
84     armnn::IConnectableLayer* layer = delegateData.m_Network->AddCastLayer();
85     layer->SetBackendId(setBackend);
86     ARMNN_ASSERT(layer != nullptr);
87 
88     armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
89     outputSlot.SetTensorInfo(outputTensorInfo);
90 
91     // try to connect the Constant Inputs if there are any
92     if (ProcessInputs(layer, delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk)
93     {
94         return kTfLiteError;
95     }
96 
97     // Connect
98     return Connect(layer, tfLiteContext, tfLiteNode, delegateData);
99 }
100 }
101