/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/topk_op.h"

#include <algorithm>
#include <numeric>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/top_n.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

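// Kernel for the TopK and TopKV2 ops. For each row of the input flattened to
// 2-D, emits the k largest values (output 0) and their column indices
// (output 1), optionally sorted in descending value order.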
template <typename Device, typename T>
class TopK : public OpKernel {
 public:
  explicit TopK(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("sorted", &sorted_));
    if (num_inputs() < 2) {  // k is an attr (TopK).
      OP_REQUIRES_OK(context, context->GetAttr("k", &k_));
    } else {  // k is an input (TopKV2), so we won't know it until Compute.
      k_ = -1;
    }
  }

  void Compute(OpKernelContext* context) override {
    int k = k_;
    if (num_inputs() >= 2) {
      const auto& k_in = context->input(1);
      OP_REQUIRES(context, TensorShapeUtils::IsScalar(k_in.shape()),
                  errors::InvalidArgument("k must be scalar, got shape ",
                                          k_in.shape().DebugString()));
      k = k_in.scalar<int32>()();
    }
    OP_REQUIRES(context, k >= 0,
                errors::InvalidArgument("Need k >= 0, got ", k));
    const auto& input_in = context->input(0);
    OP_REQUIRES(context, input_in.dims() >= 1,
                errors::InvalidArgument("input must be >= 1-D, got shape ",
                                        input_in.shape().DebugString()));
    OP_REQUIRES(context, input_in.dim_size(input_in.dims() - 1) >= k,
                errors::InvalidArgument(
                    "input must have at least k columns. Had ",
                    input_in.dim_size(input_in.dims() - 1), ", needed ", k));

    const auto& input = input_in.flat_inner_dims<T>();

    const int64_t num_rows = input.dimension(0);  // generally batch_size
    const int64_t num_cols = input.dimension(1);
    OP_REQUIRES(
        context, num_rows <= std::numeric_limits<int32>::max(),
        errors::InvalidArgument(
            "First dimension of flattened input must be <= INT_MAX, got ",
            num_rows));
    OP_REQUIRES(
        context, num_cols <= std::numeric_limits<int32>::max(),
        errors::InvalidArgument(
            "Second dimension of flattened input must be <= INT_MAX, got ",
            num_cols));

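    // Both outputs have the input's shape, with the last dimension replaced
    // by k.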
    TensorShape output_shape = input_in.shape();
    output_shape.set_dim(input_in.dims() - 1, k);
    Tensor* values_out = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, output_shape, &values_out));
    Tensor* indices_out = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(1, output_shape, &indices_out));

    // Nothing to do for top-nothing or over nothing.
    if (k == 0 || num_rows == 0) return;

    auto values = values_out->flat_inner_dims<T>();
    auto indices = indices_out->flat_inner_dims<int32>();
    Status s = functor::TopKFunctor<Device, T>::Compute(
        context, sorted_, k, input, num_rows, num_cols, values, indices);
    OP_REQUIRES_OK(context, s);
  }

 private:
  int k_;
  bool sorted_;
};

namespace functor {

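// CPU specialization: selects the top k entries of each row on the host,
// sharding rows across the intra-op thread pool.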
template <typename T>
struct TopKFunctor<CPUDevice, T> {
  static EIGEN_ALWAYS_INLINE Status Compute(
      OpKernelContext* context, bool sorted, int k,
      const typename TTypes<T, 2>::ConstTensor& input, const int64_t num_rows,
      const int64_t num_cols, typename TTypes<T, 2>::Tensor values,
      typename TTypes<int, 2>::Tensor indices) {
    const CPUDevice& d = context->eigen_device<CPUDevice>();

    // Special case for k == 1.
    if (k == 1) {
      typename Eigen::IndexList<Eigen::type2index<1>> reduce_on_cols;
      typename Eigen::IndexList<int, Eigen::type2index<1>> rows_by_one;
      rows_by_one.set(0, num_rows);

      values.device(d) =
          input.maximum(/*dims=*/reduce_on_cols).eval().reshape(rows_by_one);
      // Get the indices of the maximum values.
      for (int r = 0; r < num_rows; ++r) {
        indices(r, 0) = 0;
        for (int c = 0; c < num_cols; ++c) {
          if (values(r, 0) == input(r, c)) {
            indices(r, 0) = c;
            break;
          }
        }
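        // Re-read the value at the chosen index so values and indices stay
        // consistent even when no column compared equal to the reduced
        // maximum (e.g. rows containing NaN, where indices(r, 0) stays 0).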
        values(r, 0) = input(r, indices(r, 0));
      }

      return OkStatus();
    }

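    // Per-shard worker: computes the top-k indices and values for every row
    // in [start_batch, limit_batch). Invoked in parallel via Shard() below.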
    auto SortIndices = [&](int64_t start_batch, int64_t limit_batch) {
      for (int32_t b = start_batch; b < limit_batch; ++b) {
        const T* input_data = &input(b, 0);
        const auto stable_comp = [input_data](const int32_t a,
                                              const int32_t b) {
          if (input_data[b] < input_data[a]) {
            return true;
          } else if (input_data[b] > input_data[a]) {
            return false;
          } else {
            return a < b;
          }
        };
        const auto comp = [input_data](const int32_t a, const int32_t b) {
          return input_data[b] < input_data[a];
        };
        // TODO(ebrevdo): For large k < num_cols, instead of using
        // TopN, it may be faster to create a temporary vector of
        // values 0..num_cols - 1 and then use std::partial_sort_copy
        // of this into indices. Choosing the appropriate minimum k or
        // ratio of k/num_cols will require some experimentation.
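        // A minimal sketch of that alternative, kept commented out since the
        // crossover point is untested (`scratch` is a hypothetical name):
        //   std::vector<int32> scratch(num_cols);
        //   std::iota(scratch.begin(), scratch.end(), 0);
        //   std::partial_sort_copy(scratch.begin(), scratch.end(),
        //                          &indices(b, 0), &indices(b, k),
        //                          stable_comp);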
        if (k == num_cols) {
          auto* begin = &indices(b, 0);
          auto* end = &indices(b, k);
          // Set the initial array of indices 0 ... k - 1.
          std::iota(begin, end, 0);
          // We want an in-place sort, but we can cheat because we're sorting
          // indices that started out sorted.  First, do a std::sort, which
          // is notably faster than std::stable_sort.
          std::sort(begin, end, comp);
          // Then, for runs of adjacent elements that were equal, sort the
          // indices in those runs in increasing order.
          for (auto* run_begin = begin; run_begin != end;) {
            auto* run_end = run_begin + 1;
            if (run_end == end) break;
            if (input_data[*run_begin] == input_data[*run_end]) {
              while (++run_end != end) {
                if (input_data[*run_begin] != input_data[*run_end]) break;
              }
              std::sort(run_begin, run_end);
            }
            run_begin = run_end;
          }
        } else {
          // Use the TopN heap object to select the top k entries.
          gtl::TopN<int32, decltype(stable_comp)> filter(k, stable_comp);
          filter.reserve(num_cols);
          for (int32_t c = 0; c < num_cols; ++c) {
            filter.push(c);
          }

          int32_t i = 0;
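          // Extract() returns the retained indices ordered best-first, which
          // yields sorted output; the unsorted iterators are cheaper when
          // sorted results aren't needed.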
          if (sorted) {
            std::unique_ptr<std::vector<int32>> top_k(filter.Extract());
            for (auto top_k_it = top_k->begin(); top_k_it != top_k->end();
                 ++top_k_it, ++i) {
              indices(b, i) = *top_k_it;
            }
          } else {
            for (auto top_k_it = filter.unsorted_begin();
                 top_k_it != filter.unsorted_end(); ++top_k_it, ++i) {
              indices(b, i) = *top_k_it;
            }
          }
        }
        // Now that the indices are sorted, copy the values over in
        // sorted order.
        std::transform(
            &indices(b, 0), &indices(b, k), &values(b, 0),
            [b, &input](const int32_t loc) { return input(b, loc); });
      }  // for (int32 b = ...
    };

    // Guesstimate of cost; 4*N*log(K) where N == num_cols.
    // If K == N, assume the cost is N*log(K + 1).
    const double cmp_cost = 3 * Eigen::TensorOpCost::AddCost<int32>() +
                            Eigen::TensorOpCost::AddCost<T>();
    const double base_cost =
        cmp_cost *
        static_cast<double>(num_cols *
                            Eigen::numext::log2(static_cast<float>(k + 1)));
    const double sort_cost = (k == num_cols) ? base_cost : 4 * base_cost;
    const double copy_cost = 2 * k * Eigen::TensorOpCost::AddCost<T>();
    const double total_cost = sort_cost + copy_cost;
    const int64_t final_cost = (total_cost >= static_cast<double>(kint64max))
                                   ? kint64max
                                   : static_cast<int64_t>(total_cost);
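    // Run SortIndices over all rows in parallel on the intra-op thread pool,
    // using final_cost as the estimated per-row cost for shard sizing.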
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, num_rows,
          final_cost, SortIndices);

    return OkStatus();
  }
};

}  // namespace functor

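// CPU kernel registrations for TopK and TopKV2 over all real number types.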
#define REGISTER_KERNELS_NAME(name, type)                       \
  REGISTER_KERNEL_BUILDER(                                      \
      Name(#name).Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      TopK<CPUDevice, type>)

#define REGISTER_KERNELS(type)       \
  REGISTER_KERNELS_NAME(TopK, type); \
  REGISTER_KERNELS_NAME(TopKV2, type)

TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS_NAME
#undef REGISTER_KERNELS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace functor {
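// Forward declarations of the GPU specializations, which are instantiated in
// the corresponding GPU compilation units.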
#define DECLARE_GPU_SPEC(T)                                                    \
  template <>                                                                  \
  Status TopKFunctor<GPUDevice, T>::Compute(                                   \
      OpKernelContext* context, bool sorted, int k,                            \
      const typename TTypes<T, 2>::ConstTensor& input, const int64_t num_rows, \
      const int64_t num_cols, typename TTypes<T, 2>::Tensor values,            \
      typename TTypes<int, 2>::Tensor indices);                                \
  extern template struct functor::TopKFunctor<GPUDevice, T>;

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
TF_CALL_INTEGRAL_TYPES(DECLARE_GPU_SPEC);

#undef DECLARE_GPU_SPEC

}  // namespace functor

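// GPU kernel registrations. Note that TopKV2's scalar k input is pinned to
// host memory, since Compute reads it on the host to size the outputs.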
#define REGISTER_KERNELS(type)                                   \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("TopK").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      TopK<GPUDevice, type>)                                     \
  REGISTER_KERNEL_BUILDER(Name("TopKV2")                         \
                              .Device(DEVICE_GPU)                \
                              .TypeConstraint<type>("T")         \
                              .HostMemory("k"),                  \
                          TopK<GPUDevice, type>)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNELS);
TF_CALL_INTEGRAL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#endif  // end GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // end namespace tensorflow