xref: /aosp_15_r20/external/armnn/delegate/classic/src/GatherNd.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <ClassicDelegateUtils.hpp>
9 
10 #include <algorithm>
11 #include <iterator>
12 #include <string>
13 #include <vector>
14 
15 namespace armnnDelegate
16 {
VisitGatherNdOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t operatorCode)17 TfLiteStatus VisitGatherNdOperator(DelegateData& delegateData,
18                                  TfLiteContext* tfLiteContext,
19                                  TfLiteNode* tfLiteNode,
20                                  int nodeIndex,
21                                  int32_t operatorCode)
22 {
23     TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
24     TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
25 
26     const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
27 
28     const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
29     if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
30     {
31         return kTfLiteError;
32     }
33 
34     const TfLiteTensor& tfLiteIndicesTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
35     if (!IsValid(tfLiteContext, tfLiteIndicesTensor, operatorCode, nodeIndex))
36     {
37         return kTfLiteError;
38     }
39 
40     const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
41     if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
42     {
43         return kTfLiteError;
44     }
45 
46     const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
47     const armnn::TensorInfo& indicesTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteIndicesTensor);
48     const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
49 
50     armnn::BackendId setBackend;
51     if (!delegateData.m_Network)
52     {
53         // Check if supported
54         bool isSupported = false;
55         FORWARD_LAYER_SUPPORT_FUNC("GATHER_ND",
56                                    tfLiteContext,
57                                    IsGatherNdSupported,
58                                    delegateData.m_Backends,
59                                    isSupported,
60                                    setBackend,
61                                    inputTensorInfo,
62                                    indicesTensorInfo,
63                                    outputTensorInfo);
64         return isSupported ? kTfLiteOk : kTfLiteError;
65     }
66 
67     armnn::IConnectableLayer* layer = delegateData.m_Network->AddGatherNdLayer();
68     layer->SetBackendId(setBackend);
69     ARMNN_ASSERT(layer != nullptr);
70     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
71 
72     auto inputsTensorsProcess = ProcessInputs(layer,
73                                               delegateData,
74                                               tfLiteContext,
75                                               tfLiteNode);
76     if (inputsTensorsProcess == kTfLiteError)
77     {
78         return inputsTensorsProcess;
79     }
80 
81     return Connect(layer, tfLiteNode, delegateData);
82 }
83 } // namespace armnnDelegate