xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/unique.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <stddef.h>
17 #include <stdint.h>
18 
19 #include <map>
20 #include <memory>
21 #include <vector>
22 
23 #include "tensorflow/lite/c/builtin_op_data.h"
24 #include "tensorflow/lite/c/common.h"
25 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
26 #include "tensorflow/lite/kernels/internal/tensor.h"
27 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
28 #include "tensorflow/lite/kernels/kernel_util.h"
29 
30 namespace tflite {
31 namespace ops {
32 namespace builtin {
33 namespace unique {
34 
Init(TfLiteContext * context,const char * buffer,size_t length)35 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
36   return nullptr;
37 }
38 
Free(TfLiteContext * context,void * buffer)39 void Free(TfLiteContext* context, void* buffer) {}
40 
Prepare(TfLiteContext * context,TfLiteNode * node)41 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
42   static const int kOutputUniqueTensor = 0;
43   static const int kOutputIndexTensor = 1;
44 
45   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
46   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);
47   const TfLiteTensor* input;
48   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
49   TfLiteTensor* output_unique_tensor;
50   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutputUniqueTensor,
51                                            &output_unique_tensor));
52   TfLiteTensor* output_index_tensor;
53   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutputIndexTensor,
54                                            &output_index_tensor));
55 
56   // The op only supports 1D input.
57   TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1);
58   TfLiteIntArray* output_index_shape = TfLiteIntArrayCopy(input->dims);
59   // The unique values are determined during evaluation, so we don't know yet
60   // the size of the output tensor.
61   SetTensorToDynamic(output_unique_tensor);
62   return context->ResizeTensor(context, output_index_tensor,
63                                output_index_shape);
64 }
65 
66 namespace {
67 
68 // Actual evaluation for the unique op.
69 template <typename T, typename I>
EvalImpl(TfLiteContext * context,const TfLiteTensor * input,TfLiteNode * node)70 TfLiteStatus EvalImpl(TfLiteContext* context, const TfLiteTensor* input,
71                       TfLiteNode* node) {
72   // Map from value, to index in the unique elements vector.
73   // Note that we prefer to use map than unordered_map as it showed less
74   // increase in the binary size.
75   std::map<T, int> unique_values;
76   TfLiteTensor* output_indexes;
77   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 1, &output_indexes));
78   std::vector<T> output_values;
79   I* indexes = GetTensorData<I>(output_indexes);
80   const T* data = GetTensorData<T>(input);
81   const int num_elements = NumElements(input);
82 
83   for (int i = 0; i < num_elements; ++i) {
84     const auto element_it = unique_values.find(data[i]);
85     if (element_it != unique_values.end()) {
86       indexes[i] = element_it->second;
87     } else {
88       const int unique_index = unique_values.size();
89       unique_values[data[i]] = unique_index;
90       indexes[i] = unique_index;
91       output_values.push_back(data[i]);
92     }
93   }
94   // Allocate output tensor.
95   TfLiteTensor* unique_output;
96   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &unique_output));
97   std::unique_ptr<TfLiteIntArray, void (*)(TfLiteIntArray*)> shape(
98       TfLiteIntArrayCreate(NumDimensions(input)), TfLiteIntArrayFree);
99   shape->data[0] = unique_values.size();
100   TF_LITE_ENSURE_STATUS(
101       context->ResizeTensor(context, unique_output, shape.release()));
102   // Set the values in the output tensor.
103   T* output_unique_values = GetTensorData<T>(unique_output);
104   for (int i = 0; i < output_values.size(); ++i) {
105     output_unique_values[i] = output_values[i];
106   }
107   return kTfLiteOk;
108 }
109 
110 template <typename T>
EvalImpl(TfLiteContext * context,const TfLiteTensor * input,TfLiteNode * node)111 TfLiteStatus EvalImpl(TfLiteContext* context, const TfLiteTensor* input,
112                       TfLiteNode* node) {
113   auto* params = reinterpret_cast<TfLiteUniqueParams*>(node->builtin_data);
114   if (params == nullptr) {
115     TF_LITE_KERNEL_LOG(context, "Null params passed");
116     return kTfLiteError;
117   }
118   switch (params->index_out_type) {
119     case kTfLiteInt32:
120       return EvalImpl<T, int32_t>(context, input, node);
121     case kTfLiteInt64:
122       return EvalImpl<T, int64_t>(context, input, node);
123     default:
124       TF_LITE_KERNEL_LOG(
125           context,
126           "Unique index output array can only be Int32 or In64, requested: %s",
127           TfLiteTypeGetName(params->index_out_type));
128   }
129   return kTfLiteError;
130 }
131 
132 }  // namespace
133 
Eval(TfLiteContext * context,TfLiteNode * node)134 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
135   const TfLiteTensor* input;
136   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
137   TfLiteTensor* output_index_tensor;
138   TF_LITE_ENSURE_OK(context,
139                     GetOutputSafe(context, node, 1, &output_index_tensor));
140   TF_LITE_ENSURE_EQ(context, NumElements(output_index_tensor),
141                     NumElements(input));
142 
143   switch (input->type) {
144     case kTfLiteInt8:
145       TF_LITE_ENSURE_STATUS(EvalImpl<int8_t>(context, input, node));
146       break;
147     case kTfLiteInt16:
148       TF_LITE_ENSURE_STATUS(EvalImpl<int16_t>(context, input, node));
149       break;
150     case kTfLiteInt32:
151       TF_LITE_ENSURE_STATUS(EvalImpl<int32_t>(context, input, node));
152       break;
153     case kTfLiteInt64:
154       TF_LITE_ENSURE_STATUS(EvalImpl<int64_t>(context, input, node));
155       break;
156     case kTfLiteFloat32:
157       TF_LITE_ENSURE_STATUS(EvalImpl<float>(context, input, node));
158       break;
159     case kTfLiteUInt8:
160       TF_LITE_ENSURE_STATUS(EvalImpl<uint8_t>(context, input, node));
161       break;
162     default:
163       TF_LITE_KERNEL_LOG(context, "Currently Unique doesn't support type: %s",
164                          TfLiteTypeGetName(input->type));
165       return kTfLiteError;
166   }
167   return kTfLiteOk;
168 }
169 
170 }  // namespace unique
171 
Register_UNIQUE()172 TfLiteRegistration* Register_UNIQUE() {
173   static TfLiteRegistration r = {unique::Init, unique::Free, unique::Prepare,
174                                  unique::Eval};
175   return &r;
176 }
177 
178 }  // namespace builtin
179 }  // namespace ops
180 }  // namespace tflite
181