//
// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <tensorflow/lite/builtin_ops.h>
#include <tensorflow/lite/c/builtin_op_data.h>
#include <tensorflow/lite/c/common.h>
#include <tensorflow/lite/minimal_logging.h>

// Standard headers for std::memcpy, std::pair and std::vector used below.
#include <cstring>
#include <utility>
#include <vector>

namespace armnnDelegate
{

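// Validates a TfLite BATCH_TO_SPACE_ND node and, once delegateData.m_Network is set,
// adds the corresponding BatchToSpaceNd layer to the Arm NN network.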
TfLiteStatus VisitBatchToSpaceNdOperator(DelegateData& delegateData,
                                         TfLiteContext* tfLiteContext,
                                         TfLiteNode* tfLiteNode,
                                         int nodeIndex,
                                         int32_t operatorCode)
{
    TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 3, nodeIndex));
    TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));

    const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
    const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
    if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
    {
        return kTfLiteError;
    }

    const TfLiteTensor& tfLiteBlockShapeTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
    if (!IsValid(tfLiteContext, tfLiteBlockShapeTensor, operatorCode, nodeIndex))
    {
        return kTfLiteError;
    }

    const TfLiteTensor& tfLiteCropsTensor = tfLiteTensors[tfLiteNode->inputs->data[2]];
    if (!IsValid(tfLiteContext, tfLiteCropsTensor, operatorCode, nodeIndex))
    {
        return kTfLiteError;
    }

    const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
    if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
    {
        return kTfLiteError;
    }

    const armnn::TensorInfo& inputTensorInfo      = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
    const armnn::TensorInfo& blockShapeTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteBlockShapeTensor);
    const armnn::TensorInfo& cropsTensorInfo      = GetTensorInfoForTfLiteTensor(tfLiteCropsTensor);
    const armnn::TensorInfo& outputTensorInfo     = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);

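    // TfLite stores block_shape and crops as int32 tensors; copying their raw bytes
    // into std::vector<unsigned int> assumes both element types are 32 bits wide.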
    std::vector<unsigned int> blockShape(blockShapeTensorInfo.GetNumElements());
    std::memcpy(blockShape.data(), tfLiteBlockShapeTensor.data.data, blockShapeTensorInfo.GetNumBytes());

    std::vector<unsigned int> cropsVector(cropsTensorInfo.GetNumElements());
    std::memcpy(cropsVector.data(), tfLiteCropsTensor.data.data, cropsTensorInfo.GetNumBytes());

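    // The crops tensor is a flattened [M, 2] array of (begin, end) pairs, one pair
    // per spatial dimension; rebuild it as pairs for the descriptor.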
    size_t step = 2;
    std::vector<std::pair<unsigned int, unsigned int>> crops;
    for (unsigned int i = 0; i < cropsTensorInfo.GetNumElements() / step; ++i)
    {
        crops.emplace_back(cropsVector[i * step], cropsVector[i * step + 1]);
    }

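    // Populate the descriptor; TfLite tensors always use the NHWC data layout.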
    armnn::BatchToSpaceNdDescriptor descriptor;
    descriptor.m_BlockShape = blockShape;
    descriptor.m_Crops = crops;
    descriptor.m_DataLayout = armnn::DataLayout::NHWC;

    // Check if supported
    bool isSupported = false;
    armnn::BackendId setBackend;
    auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
    {
        FORWARD_LAYER_SUPPORT_FUNC("BATCH_TO_SPACE_ND",
                                   tfLiteContext,
                                   IsBatchToSpaceNdSupported,
                                   delegateData.m_Backends,
                                   isSupported,
                                   setBackend,
                                   inputTensorInfo,
                                   outputTensorInfo,
                                   descriptor);
    };

    // If delegateData.m_Network is a nullptr, we are in the validation phase: TfLite is
    // querying whether the backends support this operator before delegating it.
    // If it is supported, VisitBatchToSpaceNdOperator will be called again with a valid
    // m_Network to add the layer, as seen below.
    if (!delegateData.m_Network)
    {
        validateFunc(outputTensorInfo, isSupported);
        return isSupported ? kTfLiteOk : kTfLiteError;
    }

    // Add a BatchToSpace layer; assert before dereferencing the returned pointer
    armnn::IConnectableLayer* layer = delegateData.m_Network->AddBatchToSpaceNdLayer(descriptor);
    ARMNN_ASSERT(layer != nullptr);
    layer->SetBackendId(setBackend);

    armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
    outputSlot.SetTensorInfo(outputTensorInfo);

    // Try to connect the constant inputs if there are any
    if (ProcessInputs(layer, delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk)
    {
        return kTfLiteError;
    }

    // Connect
    return Connect(layer, tfLiteNode, delegateData);
}

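// Validates a TfLite SPACE_TO_BATCH_ND node and, once delegateData.m_Network is set,
// adds the corresponding SpaceToBatchNd layer to the Arm NN network.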
TfLiteStatus VisitSpaceToBatchNdOperator(DelegateData& delegateData,
                                         TfLiteContext* tfLiteContext,
                                         TfLiteNode* tfLiteNode,
                                         int nodeIndex,
                                         int32_t operatorCode)
{
    TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 3, nodeIndex));
    TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));

    const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
    const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
    if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
    {
        return kTfLiteError;
    }

    const TfLiteTensor& tfLiteBlockShapeTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
    if (!IsValid(tfLiteContext, tfLiteBlockShapeTensor, operatorCode, nodeIndex))
    {
        return kTfLiteError;
    }

    const TfLiteTensor& tfLitePadListTensor = tfLiteTensors[tfLiteNode->inputs->data[2]];
    if (!IsValid(tfLiteContext, tfLitePadListTensor, operatorCode, nodeIndex))
    {
        return kTfLiteError;
    }

    const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
    if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
    {
        return kTfLiteError;
    }

    const armnn::TensorInfo& inputTensorInfo      = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
    const armnn::TensorInfo& blockShapeTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteBlockShapeTensor);
    const armnn::TensorInfo& padListTensorInfo    = GetTensorInfoForTfLiteTensor(tfLitePadListTensor);
    const armnn::TensorInfo& outputTensorInfo     = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);

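    // TfLite stores block_shape and paddings as int32 tensors; copying their raw bytes
    // into std::vector<unsigned int> assumes both element types are 32 bits wide.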
    std::vector<unsigned int> blockShape(blockShapeTensorInfo.GetNumElements());
    std::memcpy(blockShape.data(), tfLiteBlockShapeTensor.data.data, blockShapeTensorInfo.GetNumBytes());

    std::vector<unsigned int> padListVector(padListTensorInfo.GetNumElements());
    std::memcpy(padListVector.data(), tfLitePadListTensor.data.data, padListTensorInfo.GetNumBytes());

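    // The paddings tensor is a flattened [M, 2] array of (before, after) pairs, one pair
    // per spatial dimension; rebuild it as pairs for the descriptor.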
    size_t step = 2;
    std::vector<std::pair<unsigned int, unsigned int>> padList;
    for (unsigned int i = 0; i < padListTensorInfo.GetNumElements() / step; ++i)
    {
        padList.emplace_back(padListVector[i * step], padListVector[i * step + 1]);
    }

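    // Populate the descriptor; TfLite tensors always use the NHWC data layout.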
    armnn::SpaceToBatchNdDescriptor descriptor;
    descriptor.m_BlockShape = blockShape;
    descriptor.m_PadList = padList;
    descriptor.m_DataLayout = armnn::DataLayout::NHWC;

    // Check if supported
    bool isSupported = false;
    armnn::BackendId setBackend;
    auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
    {
        FORWARD_LAYER_SUPPORT_FUNC("SPACE_TO_BATCH_ND",
                                   tfLiteContext,
                                   IsSpaceToBatchNdSupported,
                                   delegateData.m_Backends,
                                   isSupported,
                                   setBackend,
                                   inputTensorInfo,
                                   outputTensorInfo,
                                   descriptor);
    };

    // If delegateData.m_Network is a nullptr, we are in the validation phase: TfLite is
    // querying whether the backends support this operator before delegating it.
    // If it is supported, VisitSpaceToBatchNdOperator will be called again with a valid
    // m_Network to add the layer, as seen below.
    if (!delegateData.m_Network)
    {
        validateFunc(outputTensorInfo, isSupported);
        return isSupported ? kTfLiteOk : kTfLiteError;
    }

    // Add a SpaceToBatch layer; assert before dereferencing the returned pointer
    armnn::IConnectableLayer* layer = delegateData.m_Network->AddSpaceToBatchNdLayer(descriptor);
    ARMNN_ASSERT(layer != nullptr);
    layer->SetBackendId(setBackend);

    armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
    outputSlot.SetTensorInfo(outputTensorInfo);

    // Try to connect the constant inputs if there are any
    if (ProcessInputs(layer, delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk)
    {
        return kTfLiteError;
    }

    // Connect
    return Connect(layer, tfLiteNode, delegateData);
}

} // namespace armnnDelegate