xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/gather_nd.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <stdint.h>
16 
17 #include "tensorflow/lite/c/c_api_types.h"
18 #include "tensorflow/lite/c/common.h"
19 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
20 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
21 #include "tensorflow/lite/kernels/internal/tensor.h"
22 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
23 #include "tensorflow/lite/kernels/kernel_util.h"
24 
25 namespace tflite {
26 namespace ops {
27 namespace builtin {
28 namespace gather_nd {
// Position of the `params` tensor (the data being gathered from) in the
// node's input list.
constexpr int kParams{0};
// Position of the `indices` tensor in the node's input list.
constexpr int kIndices{1};
// Position of the single result tensor in the node's output list.
constexpr int kOutputTensor{0};
32 
Prepare(TfLiteContext * context,TfLiteNode * node)33 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
34   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
35   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
36 
37   const TfLiteTensor* params;
38   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kParams, &params));
39   const TfLiteTensor* indices;
40   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kIndices, &indices));
41   TfLiteTensor* output;
42   TF_LITE_ENSURE_OK(context,
43                     GetOutputSafe(context, node, kOutputTensor, &output));
44 
45   switch (params->type) {
46     case kTfLiteFloat32:
47     case kTfLiteUInt8:
48     case kTfLiteInt8:
49     case kTfLiteInt16:
50     case kTfLiteInt64:
51     case kTfLiteInt32:
52     case kTfLiteString:
53       break;
54     default:
55       TF_LITE_KERNEL_LOG(context,
56                          "Params of type '%s' are not supported by gather_nd.",
57                          TfLiteTypeGetName(params->type));
58       return kTfLiteError;
59   }
60   switch (indices->type) {
61     case kTfLiteInt64:
62     case kTfLiteInt32:
63       break;
64     default:
65       TF_LITE_KERNEL_LOG(context,
66                          "Indices of type '%s' are not supported by gather_nd.",
67                          TfLiteTypeGetName(indices->type));
68       return kTfLiteError;
69   }
70 
71   const int params_rank = NumDimensions(params);
72   const int indices_rank = NumDimensions(indices);
73   const int indices_nd = SizeOfDimension(indices, indices_rank - 1);
74   if (params_rank < 1) {
75     TF_LITE_KERNEL_LOG(context, "Params must be at least a vector.");
76     return kTfLiteError;
77   }
78   if (indices_rank < 1) {
79     TF_LITE_KERNEL_LOG(context, "Indices must be at least a vector.");
80     return kTfLiteError;
81   }
82   if (indices_nd > params_rank) {
83     TF_LITE_KERNEL_LOG(
84         context, "Index innermost dimension length must be <= params rank.");
85     return kTfLiteError;
86   }
87 
88   // Assign to output the input type.
89   output->type = params->type;
90 
91   // The result shape is
92   // indices.shape[:-1] + params.shape[indices.shape[-1]:]
93   const int output_rank = indices_rank + params_rank - indices_nd - 1;
94   TfLiteIntArray* output_shape = TfLiteIntArrayCreate(output_rank);
95   int output_index = 0;
96   for (int i = 0; i < indices_rank - 1; ++i) {
97     output_shape->data[output_index++] = indices->dims->data[i];
98   }
99   for (int i = indices_nd; i < params_rank; ++i) {
100     output_shape->data[output_index++] = params->dims->data[i];
101   }
102   return context->ResizeTensor(context, output, output_shape);
103 }
104 
105 template <typename ParamsT, typename IndicesT>
GatherNd(const TfLiteTensor * params,const TfLiteTensor * indices,TfLiteTensor * output)106 TfLiteStatus GatherNd(const TfLiteTensor* params, const TfLiteTensor* indices,
107                       TfLiteTensor* output) {
108   return reference_ops::GatherNd(
109       GetTensorShape(params), GetTensorData<ParamsT>(params),
110       GetTensorShape(indices), GetTensorData<IndicesT>(indices),
111       GetTensorShape(output), GetTensorData<ParamsT>(output));
112 }
113 
114 template <typename IndicesT>
GatherNdString(const TfLiteTensor * params,const TfLiteTensor * indices,TfLiteTensor * output)115 TfLiteStatus GatherNdString(const TfLiteTensor* params,
116                             const TfLiteTensor* indices, TfLiteTensor* output) {
117   return reference_ops::GatherNdString(
118       GetTensorShape(params), params, GetTensorShape(indices),
119       GetTensorData<IndicesT>(indices), GetTensorShape(output), output);
120 }
121 
122 template <typename IndicesT>
EvalGatherNd(TfLiteContext * context,const TfLiteTensor * params,const TfLiteTensor * indices,TfLiteTensor * output)123 TfLiteStatus EvalGatherNd(TfLiteContext* context, const TfLiteTensor* params,
124                           const TfLiteTensor* indices, TfLiteTensor* output) {
125   bool indices_has_only_positive_elements = true;
126   const auto* indices_values = GetTensorData<IndicesT>(indices);
127   const size_t num_indices = indices->bytes / sizeof(IndicesT);
128   for (size_t i = 0; i < num_indices; i++) {
129     if (indices_values[i] < 0) {
130       indices_has_only_positive_elements = false;
131       break;
132     }
133   }
134   TF_LITE_ENSURE(context, indices_has_only_positive_elements);
135 
136   TfLiteStatus status = kTfLiteError;
137   switch (params->type) {
138     case kTfLiteFloat32:
139       status = GatherNd<float, IndicesT>(params, indices, output);
140       break;
141     case kTfLiteUInt8:
142       status = GatherNd<uint8_t, IndicesT>(params, indices, output);
143       break;
144     case kTfLiteInt8:
145       status = GatherNd<int8_t, IndicesT>(params, indices, output);
146       break;
147     case kTfLiteInt16:
148       status = GatherNd<int16_t, IndicesT>(params, indices, output);
149       break;
150     case kTfLiteInt32:
151       status = GatherNd<int32_t, IndicesT>(params, indices, output);
152       break;
153     case kTfLiteInt64:
154       status = GatherNd<int64_t, IndicesT>(params, indices, output);
155       break;
156     case kTfLiteString:
157       status = GatherNdString<IndicesT>(params, indices, output);
158       break;
159     default:
160       TF_LITE_KERNEL_LOG(context,
161                          "Params type '%s' are not supported by gather_nd.",
162                          TfLiteTypeGetName(params->type));
163       return kTfLiteError;
164   }
165   if (status != kTfLiteOk) {
166     TF_LITE_KERNEL_LOG(context, "gather_nd index out of bounds");
167   }
168   return status;
169 }
170 
Eval(TfLiteContext * context,TfLiteNode * node)171 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
172   const TfLiteTensor* params;
173   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kParams, &params));
174   const TfLiteTensor* indices;
175   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kIndices, &indices));
176   TfLiteTensor* output;
177   TF_LITE_ENSURE_OK(context,
178                     GetOutputSafe(context, node, kOutputTensor, &output));
179 
180   // Prevent division by 0 in the helper.
181   // In TF, GatherND supports empty `params` only when `indices` is also empty.
182   TF_LITE_ENSURE(context,
183                  (NumElements(params) == 0 && NumElements(indices) == 0) ||
184                      NumElements(params) > 0);
185 
186   switch (indices->type) {
187     case kTfLiteInt32:
188       return EvalGatherNd<int32_t>(context, params, indices, output);
189     case kTfLiteInt64:
190       return EvalGatherNd<int64_t>(context, params, indices, output);
191     default:
192       TF_LITE_KERNEL_LOG(context,
193                          "Indices of type '%s' are not supported by gather_nd.",
194                          TfLiteTypeGetName(indices->type));
195       return kTfLiteError;
196   }
197 }
198 }  // namespace gather_nd
199 
Register_GATHER_ND()200 TfLiteRegistration* Register_GATHER_ND() {
201   static TfLiteRegistration r = {/*init*/ nullptr, /*free*/ nullptr,
202                                  gather_nd::Prepare, gather_nd::Eval};
203   return &r;
204 }
205 }  // namespace builtin
206 }  // namespace ops
207 }  // namespace tflite
208