1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <stdint.h>
16
17 #include "tensorflow/lite/c/c_api_types.h"
18 #include "tensorflow/lite/c/common.h"
19 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
20 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
21 #include "tensorflow/lite/kernels/internal/tensor.h"
22 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
23 #include "tensorflow/lite/kernels/kernel_util.h"
24
25 namespace tflite {
26 namespace ops {
27 namespace builtin {
28 namespace gather_nd {
// Positions of the GATHER_ND tensors within the node's input/output lists.
constexpr int kParams{0};        // Input: the tensor slices are gathered from.
constexpr int kIndices{1};       // Input: the coordinates to gather.
constexpr int kOutputTensor{0};  // Output: the gathered result.
32
Prepare(TfLiteContext * context,TfLiteNode * node)33 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
34 TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
35 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
36
37 const TfLiteTensor* params;
38 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kParams, ¶ms));
39 const TfLiteTensor* indices;
40 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kIndices, &indices));
41 TfLiteTensor* output;
42 TF_LITE_ENSURE_OK(context,
43 GetOutputSafe(context, node, kOutputTensor, &output));
44
45 switch (params->type) {
46 case kTfLiteFloat32:
47 case kTfLiteUInt8:
48 case kTfLiteInt8:
49 case kTfLiteInt16:
50 case kTfLiteInt64:
51 case kTfLiteInt32:
52 case kTfLiteString:
53 break;
54 default:
55 TF_LITE_KERNEL_LOG(context,
56 "Params of type '%s' are not supported by gather_nd.",
57 TfLiteTypeGetName(params->type));
58 return kTfLiteError;
59 }
60 switch (indices->type) {
61 case kTfLiteInt64:
62 case kTfLiteInt32:
63 break;
64 default:
65 TF_LITE_KERNEL_LOG(context,
66 "Indices of type '%s' are not supported by gather_nd.",
67 TfLiteTypeGetName(indices->type));
68 return kTfLiteError;
69 }
70
71 const int params_rank = NumDimensions(params);
72 const int indices_rank = NumDimensions(indices);
73 const int indices_nd = SizeOfDimension(indices, indices_rank - 1);
74 if (params_rank < 1) {
75 TF_LITE_KERNEL_LOG(context, "Params must be at least a vector.");
76 return kTfLiteError;
77 }
78 if (indices_rank < 1) {
79 TF_LITE_KERNEL_LOG(context, "Indices must be at least a vector.");
80 return kTfLiteError;
81 }
82 if (indices_nd > params_rank) {
83 TF_LITE_KERNEL_LOG(
84 context, "Index innermost dimension length must be <= params rank.");
85 return kTfLiteError;
86 }
87
88 // Assign to output the input type.
89 output->type = params->type;
90
91 // The result shape is
92 // indices.shape[:-1] + params.shape[indices.shape[-1]:]
93 const int output_rank = indices_rank + params_rank - indices_nd - 1;
94 TfLiteIntArray* output_shape = TfLiteIntArrayCreate(output_rank);
95 int output_index = 0;
96 for (int i = 0; i < indices_rank - 1; ++i) {
97 output_shape->data[output_index++] = indices->dims->data[i];
98 }
99 for (int i = indices_nd; i < params_rank; ++i) {
100 output_shape->data[output_index++] = params->dims->data[i];
101 }
102 return context->ResizeTensor(context, output, output_shape);
103 }
104
105 template <typename ParamsT, typename IndicesT>
GatherNd(const TfLiteTensor * params,const TfLiteTensor * indices,TfLiteTensor * output)106 TfLiteStatus GatherNd(const TfLiteTensor* params, const TfLiteTensor* indices,
107 TfLiteTensor* output) {
108 return reference_ops::GatherNd(
109 GetTensorShape(params), GetTensorData<ParamsT>(params),
110 GetTensorShape(indices), GetTensorData<IndicesT>(indices),
111 GetTensorShape(output), GetTensorData<ParamsT>(output));
112 }
113
114 template <typename IndicesT>
GatherNdString(const TfLiteTensor * params,const TfLiteTensor * indices,TfLiteTensor * output)115 TfLiteStatus GatherNdString(const TfLiteTensor* params,
116 const TfLiteTensor* indices, TfLiteTensor* output) {
117 return reference_ops::GatherNdString(
118 GetTensorShape(params), params, GetTensorShape(indices),
119 GetTensorData<IndicesT>(indices), GetTensorShape(output), output);
120 }
121
122 template <typename IndicesT>
EvalGatherNd(TfLiteContext * context,const TfLiteTensor * params,const TfLiteTensor * indices,TfLiteTensor * output)123 TfLiteStatus EvalGatherNd(TfLiteContext* context, const TfLiteTensor* params,
124 const TfLiteTensor* indices, TfLiteTensor* output) {
125 bool indices_has_only_positive_elements = true;
126 const auto* indices_values = GetTensorData<IndicesT>(indices);
127 const size_t num_indices = indices->bytes / sizeof(IndicesT);
128 for (size_t i = 0; i < num_indices; i++) {
129 if (indices_values[i] < 0) {
130 indices_has_only_positive_elements = false;
131 break;
132 }
133 }
134 TF_LITE_ENSURE(context, indices_has_only_positive_elements);
135
136 TfLiteStatus status = kTfLiteError;
137 switch (params->type) {
138 case kTfLiteFloat32:
139 status = GatherNd<float, IndicesT>(params, indices, output);
140 break;
141 case kTfLiteUInt8:
142 status = GatherNd<uint8_t, IndicesT>(params, indices, output);
143 break;
144 case kTfLiteInt8:
145 status = GatherNd<int8_t, IndicesT>(params, indices, output);
146 break;
147 case kTfLiteInt16:
148 status = GatherNd<int16_t, IndicesT>(params, indices, output);
149 break;
150 case kTfLiteInt32:
151 status = GatherNd<int32_t, IndicesT>(params, indices, output);
152 break;
153 case kTfLiteInt64:
154 status = GatherNd<int64_t, IndicesT>(params, indices, output);
155 break;
156 case kTfLiteString:
157 status = GatherNdString<IndicesT>(params, indices, output);
158 break;
159 default:
160 TF_LITE_KERNEL_LOG(context,
161 "Params type '%s' are not supported by gather_nd.",
162 TfLiteTypeGetName(params->type));
163 return kTfLiteError;
164 }
165 if (status != kTfLiteOk) {
166 TF_LITE_KERNEL_LOG(context, "gather_nd index out of bounds");
167 }
168 return status;
169 }
170
Eval(TfLiteContext * context,TfLiteNode * node)171 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
172 const TfLiteTensor* params;
173 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kParams, ¶ms));
174 const TfLiteTensor* indices;
175 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kIndices, &indices));
176 TfLiteTensor* output;
177 TF_LITE_ENSURE_OK(context,
178 GetOutputSafe(context, node, kOutputTensor, &output));
179
180 // Prevent division by 0 in the helper.
181 // In TF, GatherND supports empty `params` only when `indices` is also empty.
182 TF_LITE_ENSURE(context,
183 (NumElements(params) == 0 && NumElements(indices) == 0) ||
184 NumElements(params) > 0);
185
186 switch (indices->type) {
187 case kTfLiteInt32:
188 return EvalGatherNd<int32_t>(context, params, indices, output);
189 case kTfLiteInt64:
190 return EvalGatherNd<int64_t>(context, params, indices, output);
191 default:
192 TF_LITE_KERNEL_LOG(context,
193 "Indices of type '%s' are not supported by gather_nd.",
194 TfLiteTypeGetName(indices->type));
195 return kTfLiteError;
196 }
197 }
198 } // namespace gather_nd
199
Register_GATHER_ND()200 TfLiteRegistration* Register_GATHER_ND() {
201 static TfLiteRegistration r = {/*init*/ nullptr, /*free*/ nullptr,
202 gather_nd::Prepare, gather_nd::Eval};
203 return &r;
204 }
205 } // namespace builtin
206 } // namespace ops
207 } // namespace tflite
208