xref: /aosp_15_r20/external/armnn/delegate/classic/src/StridedSlice.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <armnn/utility/IgnoreUnused.hpp>
9 
10 #include <tensorflow/lite/builtin_ops.h>
11 #include <tensorflow/lite/c/builtin_op_data.h>
12 #include <tensorflow/lite/c/common.h>
13 #include <tensorflow/lite/minimal_logging.h>
14 
15 namespace armnnDelegate
16 {
17 
VisitStridedSliceOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t sliceOperatorCode)18 TfLiteStatus VisitStridedSliceOperator(DelegateData& delegateData,
19                                        TfLiteContext* tfLiteContext,
20                                        TfLiteNode* tfLiteNode,
21                                        int nodeIndex,
22                                        int32_t sliceOperatorCode)
23 {
24     TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 4, nodeIndex));
25     TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
26 
27     // Read inputs [input, begin, end, strides]
28     int numInputs = tfLiteNode->inputs->size;
29     std::vector<const TfLiteTensor*> tfLiteInputs;
30     tfLiteInputs.reserve(numInputs);
31     const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
32     for (int i = 0; i < numInputs; i++)
33     {
34         const TfLiteTensor* inputTensor = &tfLiteTensors[tfLiteNode->inputs->data[i]];
35         tfLiteInputs.push_back(inputTensor);
36         if (!IsValid(tfLiteContext, *inputTensor, sliceOperatorCode, nodeIndex))
37         {
38             return kTfLiteError;
39         }
40     }
41 
42     // We save the begin, end and strides tensors in our descriptor. Therefore we have to read those values from inputs
43     int inputRank = tfLiteInputs[0]->dims->size;
44     auto ReadInt32Input = [&](int inputIndex, std::vector<int32_t>& outputData) ->  TfLiteStatus
45     {
46         if (tfLiteInputs[inputIndex]->type != kTfLiteInt32)
47         {
48             TF_LITE_MAYBE_KERNEL_LOG(
49                 tfLiteContext,
50                 "TfLiteArmnnDelegate: The Begin-, End- and Stride-Tensors of the StridedSlice operation need to "
51                 "be of type int32. Operator: #%d node #%d: ",
52                 sliceOperatorCode, nodeIndex);
53             return kTfLiteError;
54         }
55         int rank = tfLiteInputs[inputIndex]->dims->size;
56         if (rank != 1)
57         {
58             TF_LITE_MAYBE_KERNEL_LOG(
59                 tfLiteContext,
60                 "TfLiteArmnnDelegate: The Begin-, End- and Stride-Tensors of the StridedSlice operation need to "
61                 "be a 1D-Tensor. Operator: #%d node #%d: ",
62                 sliceOperatorCode, nodeIndex);
63             return kTfLiteError;
64         }
65         int numValues = tfLiteInputs[inputIndex]->dims->data[0];
66         if (numValues != inputRank)
67         {
68             TF_LITE_MAYBE_KERNEL_LOG(
69                 tfLiteContext,
70                 "TfLiteArmnnDelegate: The number of values in the Begin-, End- and Stride-Tensors of the "
71                 "StridedSlice operation need to be equal to the rank of the Input-Tensor. Operator: #%d node #%d: ",
72                 sliceOperatorCode, nodeIndex);
73             return kTfLiteError;
74         }
75         // return tensor data
76         auto* tensorDataPtr = tflite::GetTensorData<int32_t>(tfLiteInputs[inputIndex]);
77         outputData.assign(tensorDataPtr, tensorDataPtr+numValues);
78         return kTfLiteOk;
79     };
80 
81     std::vector<int32_t> beginData;
82     if (ReadInt32Input(1, beginData) != kTfLiteOk)
83         return kTfLiteError;
84     std::vector<int32_t> endData;
85     if (ReadInt32Input(2, endData) != kTfLiteOk)
86         return kTfLiteError;
87     std::vector<int32_t> strideData;
88     if (ReadInt32Input(3, strideData) != kTfLiteOk)
89         return kTfLiteError;
90 
91     // parse built in options
92     auto* stridedSliceParams = reinterpret_cast<TfLiteStridedSliceParams*>(tfLiteNode->builtin_data);
93 
94     // Write all data to the descriptor
95     armnn::StridedSliceDescriptor descriptor;
96     descriptor.m_Begin          = std::move(beginData);
97     descriptor.m_End            = std::move(endData);
98     descriptor.m_Stride         = std::move(strideData);
99     descriptor.m_BeginMask      = stridedSliceParams->begin_mask;
100     descriptor.m_EllipsisMask   = stridedSliceParams->ellipsis_mask;
101     descriptor.m_EndMask        = stridedSliceParams->end_mask;
102     descriptor.m_NewAxisMask    = stridedSliceParams->new_axis_mask;
103     descriptor.m_ShrinkAxisMask = stridedSliceParams->shrink_axis_mask;
104     descriptor.m_DataLayout     = armnn::DataLayout::NHWC;
105 
106     // Validate output
107     const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
108     if (!IsValid(tfLiteContext, tfLiteOutputTensor, sliceOperatorCode, nodeIndex))
109     {
110         return kTfLiteError;
111     }
112 
113     const armnn::TensorInfo& inputTensorInfo  = GetTensorInfoForTfLiteTensor(*tfLiteInputs[0]);
114     const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor);
115 
116     bool isSupported = false;
117     armnn::BackendId setBackend;
118     auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported)
119     {
120         FORWARD_LAYER_SUPPORT_FUNC("STRIDED_SLICE",
121                                    tfLiteContext,
122                                    IsStridedSliceSupported,
123                                    delegateData.m_Backends,
124                                    isSupported,
125                                    setBackend,
126                                    inputTensorInfo,
127                                    outInfo,
128                                    descriptor);
129     };
130 
131     if (!delegateData.m_Network)
132     {
133         validateFunc(outputTensorInfo, isSupported);
134         return isSupported ? kTfLiteOk : kTfLiteError;
135     }
136 
137     // Add a StridedSlice layer
138     armnn::IConnectableLayer* layer = delegateData.m_Network->AddStridedSliceLayer(descriptor);
139     layer->SetBackendId(setBackend);
140     ARMNN_ASSERT(layer != nullptr);
141 
142     armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
143     outputSlot.SetTensorInfo(outputTensorInfo);
144 
145     // try to connect the Constant Inputs if there are any
146     if(ProcessInputs(layer,delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk )
147     {
148         return kTfLiteError;
149     }
150 
151     // Connect
152     return Connect(layer, tfLiteNode, delegateData);
153 }
154 
155 } // namespace armnnDelegate
156 
157