xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/topk_v2.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <stdint.h>
16 
17 #include <algorithm>
18 #include <iterator>
19 #include <vector>
20 
21 #include "tensorflow/lite/c/common.h"
22 #include "tensorflow/lite/kernels/internal/compatibility.h"
23 #include "tensorflow/lite/kernels/internal/tensor.h"
24 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
25 #include "tensorflow/lite/kernels/kernel_util.h"
26 
27 namespace tflite {
28 namespace ops {
29 namespace builtin {
30 namespace topk_v2 {
31 constexpr int kInputTensor = 0;
32 constexpr int kInputTopK = 1;
33 constexpr int kOutputValues = 0;
34 constexpr int kOutputIndexes = 1;
35 
36 namespace {
ResizeOutput(TfLiteContext * context,TfLiteNode * node)37 TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) {
38   const TfLiteTensor* top_k;
39   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTopK, &top_k));
40   // INT32 number of top results is supported.
41   TF_LITE_ENSURE_TYPES_EQ(context, top_k->type, kTfLiteInt32);
42   // Check that the tensor contains only one value.
43   TF_LITE_ENSURE_EQ(context, NumElements(top_k), 1);
44   const int32 k = *GetTensorData<int32_t>(top_k);
45 
46   const TfLiteTensor* input;
47   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
48   const int num_dimensions = NumDimensions(input);
49   // Check that input has one or more dimensions.
50   TF_LITE_ENSURE_MSG(context, input->dims->size >= 1,
51                      "TopK k input must have 1 or more dimensions.");
52   // Check that k is less or equal the internal dimension.
53   TF_LITE_ENSURE_MSG(context, k <= input->dims->data[num_dimensions - 1],
54                      "TopK k is higher than the internal dimension.");
55 
56   TfLiteIntArray* output_indexes_shape = TfLiteIntArrayCreate(num_dimensions);
57   TfLiteIntArray* output_values_shape = TfLiteIntArrayCreate(num_dimensions);
58   for (int i = 0; i < num_dimensions - 1; ++i) {
59     output_indexes_shape->data[i] = input->dims->data[i];
60     output_values_shape->data[i] = input->dims->data[i];
61   }
62   output_indexes_shape->data[num_dimensions - 1] = k;
63   output_values_shape->data[num_dimensions - 1] = k;
64   TfLiteTensor* output_indexes;
65   TF_LITE_ENSURE_OK(
66       context, GetOutputSafe(context, node, kOutputIndexes, &output_indexes));
67   TfLiteTensor* output_values;
68   TF_LITE_ENSURE_OK(
69       context, GetOutputSafe(context, node, kOutputValues, &output_values));
70   // Force output types.
71   output_indexes->type = kTfLiteInt32;
72   output_values->type = input->type;
73   auto resize_tensor = [context](TfLiteTensor* tensor, TfLiteIntArray* new_size,
74                                  TfLiteIntArray* delete_on_error) {
75     TfLiteStatus status = context->ResizeTensor(context, tensor, new_size);
76     if (status != kTfLiteOk) {
77       if (delete_on_error != nullptr) {
78         TfLiteIntArrayFree(delete_on_error);
79       }
80     }
81     return status;
82   };
83   TF_LITE_ENSURE_OK(context, resize_tensor(output_indexes, output_indexes_shape,
84                                            output_values_shape));
85   TF_LITE_ENSURE_OK(context,
86                     resize_tensor(output_values, output_values_shape, nullptr));
87   return kTfLiteOk;
88 }
89 
// Class that collects indices of top k values.  Based on template
// tensorflow::gtl::TopN<> but, for optimization, it re-uses the same container.
template <typename T>
class TopContainer {
 public:
  TopContainer() = delete;
  // `k` is the number of top elements to keep; `row_size` bounds the number of
  // candidates per row and is only used to size the initial reservation.
  // (int32_t is the exact type behind the tflite `int32` alias used by
  // callers, so the interface is unchanged.)
  TopContainer(int32_t k, int32_t row_size) : k_(k) {
    // Clamp at 0 so a (invalid) negative k cannot turn into a huge size_t.
    container_.reserve(std::max(0, std::min(k, row_size)) + 1);
  }

  // Begins collecting indices over a new row.  `values` must stay alive until
  // sorted_result() has been consumed for this row.
  void start_collecting(const T* values) {
    values_ = values;
    container_.clear();
    is_heap_ = false;
  }

  // Offers the element at index `a` of the current row as a top-k candidate.
  void push(int32_t a) {
    // Bug fix: with k == 0 the original code heapified after the very first
    // push, leaving an empty container, and the next push dereferenced
    // container_.front() on an empty vector (undefined behavior).
    if (k_ <= 0) return;
    auto comparator = [this](int32_t lhs, int32_t rhs) {
      return compare_fun(lhs, rhs);
    };
    if (!is_heap_) {
      container_.push_back(a);
      // Cast avoids the signed/unsigned comparison between size() and k_ + 1.
      if (container_.size() == static_cast<size_t>(k_) + 1) {
        std::make_heap(container_.begin(), container_.end(), comparator);
        std::pop_heap(container_.begin(), container_.end(), comparator);
        container_.pop_back();
        is_heap_ = true;
      }
    } else if (comparator(a, container_.front())) {
      // Due to how we defined comparator / compare_fun, container_.front()
      // contains the index of the smallest of the top-k elements seen so far.
      //
      // If control reaches this point, we know that the current index a
      // corresponds to an element which is bigger than the smallest of the
      // top-k elements seen so far.  Hence, we have to update the indices of
      // the top-k elements, by removing the index of the smallest top-k
      // element, adding a, and making sure container_[0:k] is still a heap.
      std::pop_heap(container_.begin(), container_.end(), comparator);
      container_.back() = a;
      std::push_heap(container_.begin(), container_.end(), comparator);
    }
  }

  // Returns the collected indices sorted in decreasing order of their values
  // (ties broken in favor of smaller indices).
  const std::vector<int32_t>& sorted_result() {
    auto comparator = [this](int32_t lhs, int32_t rhs) {
      return compare_fun(lhs, rhs);
    };
    if (!is_heap_) {
      // Note: due to the way we defined compare_fun (see comments for that
      // function) std::sort puts the indices from container_ in decreasing
      // order of the corresponding elements.
      std::sort(container_.begin(), container_.end(), comparator);
    } else {
      std::sort_heap(container_.begin(), container_.end(), comparator);
    }
    return container_;
  }

 private:
  const int32_t k_;

  // container_[0,k) holds the indices of the largest k elements from values_
  // seen so far.  If more than k elements are pushed, then elements are
  // maintained in a min-heap order: container_.front() is
  // the index of the smallest of the top-k elements see so far.
  std::vector<int32_t> container_;

  // Once more than k elements are pushed, the container becomes a min heap,
  // and is_heap_ becomes true.
  bool is_heap_ = false;

  const T* values_ = nullptr;

  // Compares indices a and b based on the corresponding elements from values_.
  //
  // Intuitively, compare_fun(a, b) returns true iff values_[b] < values_[a]
  // (notice the inversion of direction, not a typo); ties (==) are broken in
  // favor of earlier elements (i.e., a < b).
  bool compare_fun(int32_t a, int32_t b) const {
    if (values_[b] < values_[a]) {
      return true;
    } else if (values_[b] > values_[a]) {
      return false;
    } else {
      return a < b;
    }
  }
};
174 
175 // Mostly modeled on tensorflow/core/kernels/topk_op.cc for CPU.
176 template <typename T>
TopK(int32 row_size,int32 num_rows,const T * data,int32 k,int32 * output_indexes,T * output_values)177 void TopK(int32 row_size, int32 num_rows, const T* data, int32 k,
178           int32* output_indexes, T* output_values) {
179   TopContainer<T> topc(k, row_size);
180   for (int row = 0; row < num_rows; ++row) {
181     const T* values_row = data + row * row_size;
182     topc.start_collecting(values_row);
183     for (int32 c = 0; c < row_size; ++c) {
184       topc.push(c);
185     }
186 
187     // Prepare output buffers.
188     int32* indexes_row = output_indexes + row * k;
189     T* output_row = output_values + row * k;
190     // We always assume that the output is sorted.
191     const auto& top_k = topc.sorted_result();
192     std::copy(top_k.begin(), top_k.end(), indexes_row);
193     std::transform(top_k.begin(), top_k.end(), output_row,
194                    [values_row](const int32 loc) { return values_row[loc]; });
195   }
196 }
197 
198 }  // namespace
199 
Prepare(TfLiteContext * context,TfLiteNode * node)200 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
201   // Check that the inputs and outputs have the right sizes and types.
202   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
203   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);
204 
205   const TfLiteTensor* input;
206   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
207   TfLiteTensor* output_values;
208   TF_LITE_ENSURE_OK(
209       context, GetOutputSafe(context, node, kOutputValues, &output_values));
210   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output_values->type);
211 
212   const TfLiteTensor* top_k;
213   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTopK, &top_k));
214   TF_LITE_ENSURE_TYPES_EQ(context, top_k->type, kTfLiteInt32);
215 
216   // Set output dynamic if the `top_k` tensor is not constant, or the input has
217   // dynamic dimensions (indicated by dims signature).
218   if (IsConstantTensor(top_k) && !HasUnspecifiedDimension(input)) {
219     TF_LITE_ENSURE_OK(context, ResizeOutput(context, node));
220   } else {
221     TfLiteTensor* output_indexes;
222     TF_LITE_ENSURE_OK(
223         context, GetOutputSafe(context, node, kOutputIndexes, &output_indexes));
224     TfLiteTensor* output_values;
225     TF_LITE_ENSURE_OK(
226         context, GetOutputSafe(context, node, kOutputValues, &output_values));
227     SetTensorToDynamic(output_indexes);
228     SetTensorToDynamic(output_values);
229   }
230   return kTfLiteOk;
231 }
232 
Eval(TfLiteContext * context,TfLiteNode * node)233 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
234   TfLiteTensor* output_values;
235   TF_LITE_ENSURE_OK(
236       context, GetOutputSafe(context, node, kOutputValues, &output_values));
237   TfLiteTensor* output_indexes;
238   TF_LITE_ENSURE_OK(
239       context, GetOutputSafe(context, node, kOutputIndexes, &output_indexes));
240   if (IsDynamicTensor(output_values)) {
241     TF_LITE_ENSURE_OK(context, ResizeOutput(context, node));
242   }
243   const TfLiteTensor* top_k;
244   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTopK, &top_k));
245   const int32 k = top_k->data.i32[0];
246   // The tensor can have more than 2 dimensions or even be a vector, the code
247   // anyway calls the internal dimension as row;
248   const TfLiteTensor* input;
249   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
250   const int32 row_size = input->dims->data[input->dims->size - 1];
251   int32 num_rows = 1;
252   for (int i = 0; i < input->dims->size - 1; ++i) {
253     num_rows *= input->dims->data[i];
254   }
255   switch (output_values->type) {
256     case kTfLiteFloat32:
257       TopK(row_size, num_rows, GetTensorData<float>(input), k,
258            output_indexes->data.i32, GetTensorData<float>(output_values));
259       break;
260     case kTfLiteUInt8:
261       TopK(row_size, num_rows, input->data.uint8, k, output_indexes->data.i32,
262            output_values->data.uint8);
263       break;
264     case kTfLiteInt8:
265       TopK(row_size, num_rows, input->data.int8, k, output_indexes->data.i32,
266            output_values->data.int8);
267       break;
268     case kTfLiteInt32:
269       TopK(row_size, num_rows, input->data.i32, k, output_indexes->data.i32,
270            output_values->data.i32);
271       break;
272     case kTfLiteInt64:
273       TopK(row_size, num_rows, input->data.i64, k, output_indexes->data.i32,
274            output_values->data.i64);
275       break;
276     default:
277       TF_LITE_KERNEL_LOG(context, "Type %s is currently not supported by TopK.",
278                          TfLiteTypeGetName(output_values->type));
279       return kTfLiteError;
280   }
281 
282   return kTfLiteOk;
283 }
284 }  // namespace topk_v2
Register_TOPK_V2()285 TfLiteRegistration* Register_TOPK_V2() {
286   static TfLiteRegistration r = {nullptr, nullptr, topk_v2::Prepare,
287                                  topk_v2::Eval};
288   return &r;
289 }
290 }  // namespace builtin
291 }  // namespace ops
292 }  // namespace tflite
293