/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/topk_op.h"

#include <algorithm>
#include <numeric>
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/top_n.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T>
class TopK : public OpKernel {
 public:
  explicit TopK(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("sorted", &sorted_));
    if (num_inputs() < 2) {  // k is an attr (TopK).
      OP_REQUIRES_OK(context, context->GetAttr("k", &k_));
    } else {  // k is an input (TopKV2), so we won't know it until Compute.
      k_ = -1;
    }
  }

  void Compute(OpKernelContext* context) override {
    int k = k_;
    if (num_inputs() >= 2) {
      const auto& k_in = context->input(1);
      OP_REQUIRES(context, TensorShapeUtils::IsScalar(k_in.shape()),
                  errors::InvalidArgument("k must be scalar, got shape ",
                                          k_in.shape().DebugString()));
      k = k_in.scalar<int32>()();
    }
    OP_REQUIRES(context, k >= 0,
                errors::InvalidArgument("Need k >= 0, got ", k));
    const auto& input_in = context->input(0);
    OP_REQUIRES(context, input_in.dims() >= 1,
                errors::InvalidArgument("input must be >= 1-D, got shape ",
                                        input_in.shape().DebugString()));
    OP_REQUIRES(context, input_in.dim_size(input_in.dims() - 1) >= k,
                errors::InvalidArgument(
                    "input must have at least k columns. Had ",
                    input_in.dim_size(input_in.dims() - 1), ", needed ", k));
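
    // flat_inner_dims views the input as a 2-D matrix: all leading dimensions
    // are flattened into rows, and top-k is taken along the last dimension.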
    const auto& input = input_in.flat_inner_dims<T>();

    const int64_t num_rows = input.dimension(0);  // generally batch_size
    const int64_t num_cols = input.dimension(1);
    OP_REQUIRES(
        context, num_rows <= std::numeric_limits<int32>::max(),
        errors::InvalidArgument(
            "First dimension of flattened input must be <= INT_MAX, got ",
            num_rows));
    OP_REQUIRES(
        context, num_cols <= std::numeric_limits<int32>::max(),
        errors::InvalidArgument(
            "Second dimension of flattened input must be <= INT_MAX, got ",
            num_cols));

    TensorShape output_shape = input_in.shape();
    output_shape.set_dim(input_in.dims() - 1, k);
    Tensor* values_out = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, output_shape, &values_out));
    Tensor* indices_out = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(1, output_shape, &indices_out));

    // Nothing to do for top-nothing or over nothing.
    if (k == 0 || num_rows == 0) return;

    auto values = values_out->flat_inner_dims<T>();
    auto indices = indices_out->flat_inner_dims<int32>();
    Status s = functor::TopKFunctor<Device, T>::Compute(
        context, sorted_, k, input, num_rows, num_cols, values, indices);
    OP_REQUIRES_OK(context, s);
  }

 private:
  int k_;
  bool sorted_;
};

namespace functor {

template <typename T>
struct TopKFunctor<CPUDevice, T> {
  static EIGEN_ALWAYS_INLINE Status Compute(
      OpKernelContext* context, bool sorted, int k,
      const typename TTypes<T, 2>::ConstTensor& input, const int64_t num_rows,
      const int64_t num_cols, typename TTypes<T, 2>::Tensor values,
      typename TTypes<int, 2>::Tensor indices) {
    const CPUDevice& d = context->eigen_device<CPUDevice>();

    // Special case for k == 1.
    if (k == 1) {
      typename Eigen::IndexList<Eigen::type2index<1>> reduce_on_cols;
      typename Eigen::IndexList<int, Eigen::type2index<1>> rows_by_one;
      rows_by_one.set(0, num_rows);

      values.device(d) =
          input.maximum(/*dims=*/reduce_on_cols).eval().reshape(rows_by_one);
      // Get the indices of the maximum values.
      for (int r = 0; r < num_rows; ++r) {
        indices(r, 0) = 0;
        for (int c = 0; c < num_cols; ++c) {
          if (values(r, 0) == input(r, c)) {
            indices(r, 0) = c;
            break;
          }
        }
        values(r, 0) = input(r, indices(r, 0));
      }

      return OkStatus();
    }

    auto SortIndices = [&](int64_t start_batch, int64_t limit_batch) {
      for (int32_t b = start_batch; b < limit_batch; ++b) {
        const T* input_data = &input(b, 0);
        const auto stable_comp = [input_data](const int32_t a,
                                              const int32_t b) {
          if (input_data[b] < input_data[a]) {
            return true;
          } else if (input_data[b] > input_data[a]) {
            return false;
          } else {
            return a < b;
          }
        };
        const auto comp = [input_data](const int32_t a, const int32_t b) {
          return input_data[b] < input_data[a];
        };
        // TODO(ebrevdo): For large k < num_cols, instead of using
        // TopN, it may be faster to create a temporary vector of
        // values 0..num_cols - 1 and then use std::partial_sort_copy
        // of this into indices. Choosing the appropriate minimum k or
        // ratio of k/num_cols will require some experimentation.
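        // Two paths below: when k == num_cols every index is kept, so the
        // whole index range is sorted in place; otherwise a bounded TopN heap
        // of size k avoids sorting all num_cols entries.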
        if (k == num_cols) {
          auto* begin = &indices(b, 0);
          auto* end = &indices(b, k);
          // Set the initial array of indices 0 ... k - 1.
          std::iota(begin, end, 0);
          // We want an in-place sort, but we can cheat because we're sorting
          // indices that started out sorted. First, do a std::sort, which
          // is notably faster than std::stable_sort.
          std::sort(begin, end, comp);
          // Then, for runs of adjacent elements that were equal, sort the
          // indices in those runs in increasing order.
          for (auto* run_begin = begin; run_begin != end;) {
            auto* run_end = run_begin + 1;
            if (run_end == end) break;
            if (input_data[*run_begin] == input_data[*run_end]) {
              while (++run_end != end) {
                if (input_data[*run_begin] != input_data[*run_end]) break;
              }
              std::sort(run_begin, run_end);
            }
            run_begin = run_end;
          }
        } else {
          // Use the TopN heap object to sort.
          gtl::TopN<int32, decltype(stable_comp)> filter(k, stable_comp);
          filter.reserve(num_cols);
          for (int32_t c = 0; c < num_cols; ++c) {
            filter.push(c);
          }

          int32_t i = 0;
          if (sorted) {
            std::unique_ptr<std::vector<int32>> top_k(filter.Extract());
            for (auto top_k_it = top_k->begin(); top_k_it != top_k->end();
                 ++top_k_it, ++i) {
              indices(b, i) = *top_k_it;
            }
          } else {
            for (auto top_k_it = filter.unsorted_begin();
                 top_k_it != filter.unsorted_end(); ++top_k_it, ++i) {
              indices(b, i) = *top_k_it;
            }
          }
        }
        // Now that the indices are sorted, copy the values over in
        // sorted order.
        std::transform(
            &indices(b, 0), &indices(b, k), &values(b, 0),
            [b, &input](const int32_t loc) { return input(b, loc); });
      }  // for (int32 b = ...
    };

    // Guesstimate of cost; 4*N*log(K) where N == num_cols.
    // If K == N, assume the cost is N*log(K + 1).
    const double cmp_cost = 3 * Eigen::TensorOpCost::AddCost<int32>() +
                            Eigen::TensorOpCost::AddCost<T>();
    const double base_cost =
        cmp_cost *
        static_cast<double>(num_cols *
                            Eigen::numext::log2(static_cast<float>(k + 1)));
    const double sort_cost = (k == num_cols) ? base_cost : 4 * base_cost;
    const double copy_cost = 2 * k * Eigen::TensorOpCost::AddCost<T>();
    const double total_cost = sort_cost + copy_cost;
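    // Clamp the per-row cost estimate to kint64max before handing it to the
    // work sharder.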
    const int64_t final_cost = (total_cost >= static_cast<double>(kint64max))
                                   ? kint64max
                                   : static_cast<int64_t>(total_cost);
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, num_rows,
          final_cost, SortIndices);

    return OkStatus();
  }
};

}  // namespace functor

#define REGISTER_KERNELS_NAME(name, type)                       \
  REGISTER_KERNEL_BUILDER(                                      \
      Name(#name).Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      TopK<CPUDevice, type>)

#define REGISTER_KERNELS(type)       \
  REGISTER_KERNELS_NAME(TopK, type); \
  REGISTER_KERNELS_NAME(TopKV2, type)

TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS_NAME
#undef REGISTER_KERNELS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace functor {
#define DECLARE_GPU_SPEC(T)                                                    \
  template <>                                                                  \
  Status TopKFunctor<GPUDevice, T>::Compute(                                   \
      OpKernelContext* context, bool sorted, int k,                            \
      const typename TTypes<T, 2>::ConstTensor& input, const int64_t num_rows, \
      const int64_t num_cols, typename TTypes<T, 2>::Tensor values,            \
      typename TTypes<int, 2>::Tensor indices);                                \
  extern template struct functor::TopKFunctor<GPUDevice, T>;

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
TF_CALL_INTEGRAL_TYPES(DECLARE_GPU_SPEC);

#undef DECLARE_GPU_SPEC

}  // namespace functor

#define REGISTER_KERNELS(type)                                   \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("TopK").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      TopK<GPUDevice, type>)                                     \
  REGISTER_KERNEL_BUILDER(Name("TopKV2")                         \
                              .Device(DEVICE_GPU)                \
                              .TypeConstraint<type>("T")         \
                              .HostMemory("k"),                  \
                          TopK<GPUDevice, type>)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNELS);
TF_CALL_INTEGRAL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#endif  // end GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // end namespace tensorflow