1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <stddef.h>
17 #include <stdint.h>
18
19 #include <map>
20 #include <memory>
21 #include <vector>
22
23 #include "tensorflow/lite/c/builtin_op_data.h"
24 #include "tensorflow/lite/c/common.h"
25 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
26 #include "tensorflow/lite/kernels/internal/tensor.h"
27 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
28 #include "tensorflow/lite/kernels/kernel_util.h"
29
30 namespace tflite {
31 namespace ops {
32 namespace builtin {
33 namespace unique {
34
Init(TfLiteContext * context,const char * buffer,size_t length)35 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
36 return nullptr;
37 }
38
Free(TfLiteContext * context,void * buffer)39 void Free(TfLiteContext* context, void* buffer) {}
40
Prepare(TfLiteContext * context,TfLiteNode * node)41 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
42 static const int kOutputUniqueTensor = 0;
43 static const int kOutputIndexTensor = 1;
44
45 TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
46 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);
47 const TfLiteTensor* input;
48 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
49 TfLiteTensor* output_unique_tensor;
50 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutputUniqueTensor,
51 &output_unique_tensor));
52 TfLiteTensor* output_index_tensor;
53 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutputIndexTensor,
54 &output_index_tensor));
55
56 // The op only supports 1D input.
57 TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1);
58 TfLiteIntArray* output_index_shape = TfLiteIntArrayCopy(input->dims);
59 // The unique values are determined during evaluation, so we don't know yet
60 // the size of the output tensor.
61 SetTensorToDynamic(output_unique_tensor);
62 return context->ResizeTensor(context, output_index_tensor,
63 output_index_shape);
64 }
65
66 namespace {
67
68 // Actual evaluation for the unique op.
69 template <typename T, typename I>
EvalImpl(TfLiteContext * context,const TfLiteTensor * input,TfLiteNode * node)70 TfLiteStatus EvalImpl(TfLiteContext* context, const TfLiteTensor* input,
71 TfLiteNode* node) {
72 // Map from value, to index in the unique elements vector.
73 // Note that we prefer to use map than unordered_map as it showed less
74 // increase in the binary size.
75 std::map<T, int> unique_values;
76 TfLiteTensor* output_indexes;
77 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 1, &output_indexes));
78 std::vector<T> output_values;
79 I* indexes = GetTensorData<I>(output_indexes);
80 const T* data = GetTensorData<T>(input);
81 const int num_elements = NumElements(input);
82
83 for (int i = 0; i < num_elements; ++i) {
84 const auto element_it = unique_values.find(data[i]);
85 if (element_it != unique_values.end()) {
86 indexes[i] = element_it->second;
87 } else {
88 const int unique_index = unique_values.size();
89 unique_values[data[i]] = unique_index;
90 indexes[i] = unique_index;
91 output_values.push_back(data[i]);
92 }
93 }
94 // Allocate output tensor.
95 TfLiteTensor* unique_output;
96 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &unique_output));
97 std::unique_ptr<TfLiteIntArray, void (*)(TfLiteIntArray*)> shape(
98 TfLiteIntArrayCreate(NumDimensions(input)), TfLiteIntArrayFree);
99 shape->data[0] = unique_values.size();
100 TF_LITE_ENSURE_STATUS(
101 context->ResizeTensor(context, unique_output, shape.release()));
102 // Set the values in the output tensor.
103 T* output_unique_values = GetTensorData<T>(unique_output);
104 for (int i = 0; i < output_values.size(); ++i) {
105 output_unique_values[i] = output_values[i];
106 }
107 return kTfLiteOk;
108 }
109
110 template <typename T>
EvalImpl(TfLiteContext * context,const TfLiteTensor * input,TfLiteNode * node)111 TfLiteStatus EvalImpl(TfLiteContext* context, const TfLiteTensor* input,
112 TfLiteNode* node) {
113 auto* params = reinterpret_cast<TfLiteUniqueParams*>(node->builtin_data);
114 if (params == nullptr) {
115 TF_LITE_KERNEL_LOG(context, "Null params passed");
116 return kTfLiteError;
117 }
118 switch (params->index_out_type) {
119 case kTfLiteInt32:
120 return EvalImpl<T, int32_t>(context, input, node);
121 case kTfLiteInt64:
122 return EvalImpl<T, int64_t>(context, input, node);
123 default:
124 TF_LITE_KERNEL_LOG(
125 context,
126 "Unique index output array can only be Int32 or In64, requested: %s",
127 TfLiteTypeGetName(params->index_out_type));
128 }
129 return kTfLiteError;
130 }
131
132 } // namespace
133
Eval(TfLiteContext * context,TfLiteNode * node)134 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
135 const TfLiteTensor* input;
136 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
137 TfLiteTensor* output_index_tensor;
138 TF_LITE_ENSURE_OK(context,
139 GetOutputSafe(context, node, 1, &output_index_tensor));
140 TF_LITE_ENSURE_EQ(context, NumElements(output_index_tensor),
141 NumElements(input));
142
143 switch (input->type) {
144 case kTfLiteInt8:
145 TF_LITE_ENSURE_STATUS(EvalImpl<int8_t>(context, input, node));
146 break;
147 case kTfLiteInt16:
148 TF_LITE_ENSURE_STATUS(EvalImpl<int16_t>(context, input, node));
149 break;
150 case kTfLiteInt32:
151 TF_LITE_ENSURE_STATUS(EvalImpl<int32_t>(context, input, node));
152 break;
153 case kTfLiteInt64:
154 TF_LITE_ENSURE_STATUS(EvalImpl<int64_t>(context, input, node));
155 break;
156 case kTfLiteFloat32:
157 TF_LITE_ENSURE_STATUS(EvalImpl<float>(context, input, node));
158 break;
159 case kTfLiteUInt8:
160 TF_LITE_ENSURE_STATUS(EvalImpl<uint8_t>(context, input, node));
161 break;
162 default:
163 TF_LITE_KERNEL_LOG(context, "Currently Unique doesn't support type: %s",
164 TfLiteTypeGetName(input->type));
165 return kTfLiteError;
166 }
167 return kTfLiteOk;
168 }
169
170 } // namespace unique
171
Register_UNIQUE()172 TfLiteRegistration* Register_UNIQUE() {
173 static TfLiteRegistration r = {unique::Init, unique::Free, unique::Prepare,
174 unique::Eval};
175 return &r;
176 }
177
178 } // namespace builtin
179 } // namespace ops
180 } // namespace tflite
181