/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#include "tensorflow/core/kernels/image/non_max_suppression_op.h"

#include <limits>

#include "absl/strings/str_cat.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/gpu_prim.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
#include "tensorflow/core/util/gpu_launch_config.h"
#include "tensorflow/stream_executor/stream_executor.h"

namespace tensorflow {
namespace {

struct
#if GOOGLE_CUDA
    __align__(16)
#endif
        Box {
  float x1, y1, x2, y2;
};
typedef Eigen::GpuDevice GPUDevice;
typedef Eigen::ThreadPoolDevice CPUDevice;

// This is the width of the bitmask for masking boxes for each thread.  This
// needs to be a power of 2 (a POD width usually) so that division and modulo
// can be implemented as bit operations during host selection.
constexpr int kNmsBoxesPerThread = 8 * sizeof(int);

// Helpers to calculate the modulo mask and shift bits.
//
// For kNmsBoxesPerThread=32, ModuloMask will be 31, i.e. 0x1F, thus
// i % 32 == i & 31. Similarly, ShiftBits will be 5 so that
// i / 32 == i >> 5. Using these bit operations should reduce the stall on the
// host thread.
constexpr int NumBits(int n) { return (n == 0) ? 0 : NumBits(n >> 1) + 1; }
constexpr int kNmsBoxesPerThreadModuloMask = kNmsBoxesPerThread - 1;
constexpr int kNmsBoxesPerThreadShiftBits =
    NumBits(kNmsBoxesPerThreadModuloMask);
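// Compile-time sanity check of the mask/shift trick above (added here purely
// for illustration): the identities hold only because kNmsBoxesPerThread is a
// power of two. For example, 37 % 32 == (37 & 31) == 5 and
// 37 / 32 == (37 >> 5) == 1.
static_assert((kNmsBoxesPerThread & (kNmsBoxesPerThread - 1)) == 0,
              "kNmsBoxesPerThread must be a power of two");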

constexpr int kNmsBlockDim = 16;
constexpr int kNmsBlockDimMax = 128;
constexpr int kNmsChunkSize = 2000;

template <typename T>
__device__ EIGEN_STRONG_INLINE void Swap(T& a, T& b) {
  T c(a);
  a = b;
  b = c;
}

// Check whether two boxes have an IoU greater than the threshold.
template <typename T>
__device__ EIGEN_STRONG_INLINE bool OverThreshold(const Box* a, const Box* b,
                                                  const float a_area,
                                                  const T iou_threshold) {
  const float b_area = (b->x2 - b->x1) * (b->y2 - b->y1);
  if (a_area == 0.0f || b_area == 0.0f) return false;
  const float xx1 = fmaxf(a->x1, b->x1);
  const float yy1 = fmaxf(a->y1, b->y1);
  const float xx2 = fminf(a->x2, b->x2);
  const float yy2 = fminf(a->y2, b->y2);

  // fdimf computes the positive difference between xx2 and xx1.
  const float w = fdimf(xx2, xx1);
  const float h = fdimf(yy2, yy1);
  const float intersection = w * h;

  // Testing aa/bb > t is equivalent to aa > bb*t (bb != 0 here),
  // which avoids a division.
  const float aa = intersection;
  const float bb = a_area + b_area - intersection;
  const float bt = bb * iou_threshold;
  return aa > bt;
}
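// Worked example for the division-free test above (illustration only): with
// a_area == b_area == 4.0f and intersection == 2.0f, the union is 6.0f and the
// IoU is 1/3, so OverThreshold returns true exactly when
// 2 > 6 * iou_threshold, i.e. when iou_threshold < 1/3.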

template <bool flip_box>
__device__ EIGEN_STRONG_INLINE void Flipped(Box& box);

template <>
__device__ EIGEN_STRONG_INLINE void Flipped<false>(Box& box) {}

template <>
__device__ EIGEN_STRONG_INLINE void Flipped<true>(Box& box) {
  if (box.x1 > box.x2) Swap(box.x1, box.x2);
  if (box.y1 > box.y2) Swap(box.y1, box.y2);
}
template <typename T>
__device__ EIGEN_STRONG_INLINE bool CheckBit(T* bit_mask, uint32 bit) {
  constexpr uint32 kNumBits = 8 * sizeof(T);
  return (bit_mask[bit / kNumBits] >> (bit % kNumBits)) & 1;
}
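// For example (illustration only), with T == int (32 bits),
// CheckBit(bit_mask, 37) reads bit 5 of bit_mask[1], since 37 / 32 == 1 and
// 37 % 32 == 5.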

// Produce a global bitmask (result_mask) of selected boxes from the bitmask
// generated by NMSKernel. Abort early if max_boxes boxes are selected. The
// bitmask is num_boxes*bit_mask_len bits indicating whether to keep or remove
// a box.
__global__ void NMSReduce(const int* bitmask, const int bit_mask_len,
                          const int num_boxes, const int max_boxes,
                          char* result_mask) {
  extern __shared__ int local[];
  // Set global mask to accept all boxes.
  for (int box : GpuGridRangeX(bit_mask_len)) {
    local[box] = 0xFFFFFFFF;
  }
  __syncthreads();

  int accepted_boxes = 0;
  for (int box = 0; box < num_boxes - 1; ++box) {
    // If the current box is masked by an earlier box, skip it.
    if (!CheckBit(local, box)) {
      continue;
    }
    accepted_boxes += 1;
    int offset = box * bit_mask_len;
    // Update the global mask with the current box's mask.
    for (int b : GpuGridRangeX(bit_mask_len)) {
      local[b] &= ~bitmask[offset + b];
    }
    __syncthreads();
    if (accepted_boxes > max_boxes) break;
  }

  // Copy the global mask to the result_mask char array, which we use in
  // cub::DeviceSelect later.  In theory we could skip this step and use the
  // bitmask in DeviceSelect directly, but in practice this part of the kernel
  // is very cheap anyway.
  for (int box : GpuGridRangeX(num_boxes)) {
    result_mask[box] = CheckBit(local, box);
  }
}
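// Sizing note (illustration only): with num_boxes == 100 and
// kNmsBoxesPerThread == 32, bit_mask_len == 4, so `local` holds 4 ints (128
// bits) and NMSReduce is launched with bit_mask_len * sizeof(int) bytes of
// dynamic shared memory (see NmsGpu below).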

// For each box, compute a bitmask of boxes which have an overlap with the
// given box above the threshold.
//
// Starting from the highest scoring box, mark any box which has IoU>threshold
// with the given box. Each thread processes kNmsBoxesPerThread boxes per
// stride, and each box has a bitmask of overlaps of length bit_mask_len.
//
// If flip_box is true, boxes may have x1>x2 and/or y1>y2; in that case the
// coordinates are swapped so that x1<x2 and y1<y2 for all boxes. Otherwise
// boxes are assumed to already satisfy x1<x2 and y1<y2.
template <bool flip_box>
__launch_bounds__(kNmsBlockDim* kNmsBlockDim, 4) __global__
    void NMSKernel(const Box* d_desc_sorted_boxes, const int num_boxes,
                   const float iou_threshold, const int bit_mask_len,
                   int* d_delete_mask) {
  // Store the boxes used by this CUDA block in shared memory.
  __shared__ Box shared_i_boxes[kNmsBlockDim];
  // Same thing for the areas.
  __shared__ float shared_i_areas[kNmsBlockDim];
  // The condition of the for loop is common to all threads in the block.
  // This is necessary to be able to call __syncthreads() inside of the loop.
  for (int i_block_offset = blockIdx.x * blockDim.x; i_block_offset < num_boxes;
       i_block_offset += blockDim.x * gridDim.x) {
    const int i = i_block_offset + threadIdx.x;
    if (i < num_boxes) {
      // One 1D line of threads loads the boxes for the x-dimension.
      if (threadIdx.y == 0) {
        Box box = d_desc_sorted_boxes[i];
        Flipped<flip_box>(box);
        shared_i_boxes[threadIdx.x] = box;
        shared_i_areas[threadIdx.x] = (box.x2 - box.x1) * (box.y2 - box.y1);
      }
    }
    __syncthreads();
    for (int j_thread_offset =
             kNmsBoxesPerThread * (blockIdx.y * blockDim.y + threadIdx.y);
         j_thread_offset < num_boxes;
         j_thread_offset += kNmsBoxesPerThread * blockDim.y * gridDim.y) {
      // Note: we can do everything using multiplication,
      // and use fp16 - we are comparing against a low precision
      // threshold.
      int above_threshold = 0;
      // Make sure that threads are within the valid domain.
      bool valid = false;
      // Loop over the next kNmsBoxesPerThread boxes and set the corresponding
      // bit if it overlaps with the current box.
      for (int ib = 0; ib < kNmsBoxesPerThread; ++ib) {
        // This thread will compare Box i and Box j.
        const int j = j_thread_offset + ib;
        if (i >= j || i >= num_boxes || j >= num_boxes) continue;
        valid = true;
        Box j_box = d_desc_sorted_boxes[j];
        const Box i_box = shared_i_boxes[threadIdx.x];
        Flipped<flip_box>(j_box);
        if (OverThreshold<float>(&i_box, &j_box, shared_i_areas[threadIdx.x],
                                 iou_threshold)) {
          // We have score[j] <= score[i].
          above_threshold |= (1U << ib);
        }
      }
      if (valid) {
        d_delete_mask[i * bit_mask_len + j_thread_offset / kNmsBoxesPerThread] =
            above_threshold;
      }
    }
    __syncthreads();  // making sure everyone is done reading shared memory.
  }
}
// Variadic template helpers for index-selecting from multiple arrays at the
// same time.
template <typename Index>
__device__ EIGEN_STRONG_INLINE void SelectHelper(const Index i_selected,
                                                 const Index i_original) {}

template <typename Index, typename T, typename... Args>
__device__ EIGEN_STRONG_INLINE void SelectHelper(const Index i_selected,
                                                 const Index i_original,
                                                 const T* original, T* selected,
                                                 Args... args) {
  selected[i_selected] = original[i_original];
  SelectHelper(i_selected, i_original, args...);
}

// Helper template to select elements from the original arrays using the index
// mapping and store them into the selected arrays. Each array sharing the same
// mapping needs to be passed as a pair of pointers to the original and
// selected arrays. For selecting from 2 arrays the call would be
// IndexMultiSelect(num_elements, indices, original1, selected1, original2,
// selected2).
template <typename Index, typename T, typename... Args>
__global__ void IndexMultiSelect(const int num_elements, const Index* indices,
                                 const T* original, T* selected, Args... args) {
  for (const int idx : GpuGridRangeX(num_elements)) {
    SelectHelper(idx, indices[idx], original, selected, args...);
  }
}

template <typename T>
__global__ void Iota(const int num_elements, const T offset, T* to_fill) {
  for (int idx : GpuGridRangeX(num_elements)) {
    to_fill[idx] = static_cast<T>(idx) + offset;
  }
}
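// For example (illustration only), launching Iota<int> with num_elements == 5
// and offset == 0 fills to_fill with {0, 1, 2, 3, 4}; DoNMS below uses it to
// initialize the box indices before sorting.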

// TensorFlow with nvcc doesn't build with --extended-lambda, so we have to use
// an explicit functor instead of a device lambda.
struct GreaterThanCubOp {
  float threshold_;
  __host__ __device__ __forceinline__ GreaterThanCubOp(float threshold)
      : threshold_(threshold) {}
  __host__ __device__ __forceinline__ bool operator()(const float& val) const {
    return (val > threshold_);
  }
};
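// Typical use (see CountIf below): passed as the selection op to
// gpuprim::DeviceSelect::If, e.g. CountIf(context,
// d_sorted_scores.flat<float>().data(), GreaterThanCubOp(score_threshold),
// num_boxes) counts the scores strictly above score_threshold, as DoNMS does
// for the score-filtering path.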

// Uses DeviceSelect::If to count the number of elements.
//
// (It might be better to use DeviceReduce::Sum with a custom iterator to do the
// count.  But in practice SelectIf is quite fast.)
template <typename Op>
StatusOr<int> CountIf(OpKernelContext* context, const float* dev_array,
                      const Op& op, int num_elements) {
  size_t workspace_size = 0;
  auto cuda_stream = tensorflow::GetGpuStream(context);
  auto device = context->eigen_gpu_device();
  gpuprim::DeviceSelect::If(nullptr, workspace_size,
                            static_cast<float*>(nullptr),
                            static_cast<float*>(nullptr),
                            static_cast<int*>(nullptr), num_elements, op);

  Tensor scratch_output;
  TF_RETURN_IF_ERROR(context->allocate_temp(
      DataType::DT_FLOAT, TensorShape({num_elements}), &scratch_output));

  Tensor workspace;
  TF_RETURN_IF_ERROR(context->allocate_temp(
      DataType::DT_INT8, TensorShape({(int64)workspace_size}), &workspace));

  // num_selected is a host pinned tensor.  The GPU kernel can write to it
  // directly, instead of writing to GPU memory and then copying down to
  // num_selected, saving us a small D2H memcpy.  We've observed that even small
  // D2H copies on the compute stream can have an outsized effect on latency.
  Tensor num_selected;
  AllocatorAttributes pinned_alloc_attrs;
  pinned_alloc_attrs.set_on_host(true);
  pinned_alloc_attrs.set_gpu_compatible(true);
  TF_RETURN_IF_ERROR(context->allocate_temp(
      DataType::DT_INT32, TensorShape({1}), &num_selected, pinned_alloc_attrs));

  gpuEvent_t copy_done;
  TF_RETURN_IF_CUDA_ERROR(
      gpuEventCreateWithFlags(&copy_done, gpuEventDisableTiming));
  TF_RETURN_IF_CUDA_ERROR(gpuprim::DeviceSelect::If(
      workspace.flat<int8>().data(), workspace_size, dev_array,
      scratch_output.flat<float>().data(), num_selected.flat<int32>().data(),
      num_elements, op, cuda_stream));
  TF_RETURN_IF_CUDA_ERROR(gpuEventRecord(copy_done, device.stream()));
  TF_RETURN_IF_CUDA_ERROR(gpuEventSynchronize(copy_done));
  return *num_selected.flat<int32>().data();
}

Status DoNMS(OpKernelContext* context, const Tensor& boxes,
             const Tensor& scores, const int64_t max_output_size,
             const float iou_threshold_val, const float score_threshold,
             bool pad_to_max_output, int* num_saved_outputs) {
  int num_boxes = boxes.dim_size(0);
  size_t cub_sort_temp_storage_bytes = 0;
  auto cuda_stream = GetGpuStream(context);
  auto device = context->eigen_gpu_device();
  // Calling cub with nullptrs as inputs will make it return the
  // workspace size needed for the operation instead of doing the operation.
  // In this specific instance, cub_sort_temp_storage_bytes will contain the
  // necessary workspace size for sorting after the call.
  if (num_boxes == 0) {
    Tensor* output_indices = nullptr;
    TF_RETURN_IF_ERROR(
        context->allocate_output(0, TensorShape({0}), &output_indices));
    return Status::OK();
  }

  cudaError_t cuda_ret = gpuprim::DeviceRadixSort::SortPairsDescending(
      nullptr, cub_sort_temp_storage_bytes,
      static_cast<float*>(nullptr),  // scores
      static_cast<float*>(nullptr),  // sorted scores
      static_cast<int*>(nullptr),    // input indices
      static_cast<int*>(nullptr),    // sorted indices
      num_boxes,                     // num items
      0, 8 * sizeof(float),          // sort all bits
      cuda_stream);
  TF_RETURN_IF_CUDA_ERROR(cuda_ret);
  TF_RETURN_IF_CUDA_ERROR(cudaGetLastError());

  Tensor d_cub_sort_buffer;
  TF_RETURN_IF_ERROR(context->allocate_temp(
      DataType::DT_INT8, TensorShape({(int64)cub_sort_temp_storage_bytes}),
      &d_cub_sort_buffer));
  Tensor d_indices;
  TF_RETURN_IF_ERROR(context->allocate_temp(
      DataType::DT_INT32, TensorShape({num_boxes}), &d_indices));
  Tensor d_sorted_indices;
  TF_RETURN_IF_ERROR(context->allocate_temp(
      DataType::DT_INT32, TensorShape({num_boxes}), &d_sorted_indices));
  Tensor d_selected_indices;
  TF_RETURN_IF_ERROR(context->allocate_temp(
      DataType::DT_INT32, TensorShape({num_boxes}), &d_selected_indices));
  Tensor d_sorted_scores;
  TF_RETURN_IF_ERROR(context->allocate_temp(
      DataType::DT_FLOAT, TensorShape({num_boxes}), &d_sorted_scores));
  Tensor d_sorted_boxes;
  TF_RETURN_IF_ERROR(context->allocate_temp(
      DataType::DT_FLOAT, TensorShape({num_boxes, 4}), &d_sorted_boxes));

  // This will return the sorted scores and their indices.
  auto config = GetGpuLaunchConfig(num_boxes, device);
  // Initialize the box and score indices.
  TF_CHECK_OK(GpuLaunchKernel(Iota<int>, config.block_count,
                              config.thread_per_block, 0, device.stream(),
                              config.virtual_thread_count, 0,
                              d_indices.flat<int>().data()));
  TF_RETURN_IF_CUDA_ERROR(cudaGetLastError());
  cuda_ret = gpuprim::DeviceRadixSort::SortPairsDescending(
      d_cub_sort_buffer.flat<int8>().data(), cub_sort_temp_storage_bytes,
      scores.flat<float>().data(), d_sorted_scores.flat<float>().data(),
      d_indices.flat<int>().data(), d_sorted_indices.flat<int>().data(),
      num_boxes, 0,
      8 * sizeof(float),  // sort all bits
      cuda_stream);
  TF_RETURN_IF_CUDA_ERROR(cuda_ret);

  // Get pointers for easy access.
  const float4* original_boxes =
      reinterpret_cast<const float4*>(boxes.flat<float>().data());
  float4* sorted_boxes =
      reinterpret_cast<float4*>(d_sorted_boxes.flat<float>().data());
  const int* sorted_indices = d_sorted_indices.flat<int>().data();
  // Sort the boxes using the indices.
  TF_CHECK_OK(GpuLaunchKernel(IndexMultiSelect<int, float4>, config.block_count,
                              config.thread_per_block, 0, device.stream(),
                              config.virtual_thread_count, sorted_indices,
                              original_boxes, sorted_boxes));
  int limited_num_boxes = num_boxes;
  // Filter boxes by scores if NMS v3.
  if (score_threshold > std::numeric_limits<float>::lowest()) {
    GreaterThanCubOp score_limit(score_threshold);
    TF_ASSIGN_OR_RETURN(limited_num_boxes,
                        CountIf(context, d_sorted_scores.flat<float>().data(),
                                score_limit, num_boxes));
    if (limited_num_boxes == 0) {
      Tensor* output_indices = nullptr;
      VLOG(1) << "Number of boxes above score threshold " << score_threshold
              << " is 0";
      int len_output = pad_to_max_output ? max_output_size : 0;
      *num_saved_outputs = 0;
      TF_RETURN_IF_ERROR(context->allocate_output(0, TensorShape({len_output}),
                                                  &output_indices));
      return Status::OK();
    } else {
      VLOG(2) << "Number of boxes above threshold=" << score_threshold << " is "
              << limited_num_boxes;
    }
  }
  int num_to_keep = 0;
  // There is no guarantee that boxes are given in the form x1<x2 and/or y1<y2,
  // so flip boxes if necessary!
  const bool flip_boxes = true;
  auto status = NmsGpu(d_sorted_boxes.flat<float>().data(), limited_num_boxes,
                       iou_threshold_val, d_selected_indices.flat<int>().data(),
                       &num_to_keep, context, max_output_size, flip_boxes);
  TF_RETURN_IF_CUDA_ERROR(cudaGetLastError());
  if (!status.ok()) {
    context->SetStatus(status);
    return status;
  }
  Tensor* output_indices = nullptr;
  int num_outputs = std::min(num_to_keep, (int)max_output_size);  // no padding!
  if (pad_to_max_output && num_outputs != max_output_size) {
    TF_RETURN_IF_ERROR(context->allocate_output(
        0, TensorShape({max_output_size}), &output_indices));
    config = GetGpuLaunchConfig(max_output_size, device);
    TF_CHECK_OK(GpuLaunchKernel(SetZero<int>, config.block_count,
                                config.thread_per_block, 0, device.stream(),
                                config.virtual_thread_count,
                                output_indices->flat<int>().data()));

  } else {
    TF_RETURN_IF_ERROR(context->allocate_output(0, TensorShape({num_outputs}),
                                                &output_indices));
  }
  if (num_outputs == 0) {
    *num_saved_outputs = num_outputs;
    return Status::OK();
  }
  config = GetGpuLaunchConfig(num_outputs, device);
  TF_CHECK_OK(GpuLaunchKernel(
      IndexMultiSelect<int, int>, config.block_count, config.thread_per_block,
      0, device.stream(), config.virtual_thread_count,
      d_selected_indices.flat<int>().data(), sorted_indices,
      (*output_indices).flat<int>().data()));
  TF_RETURN_IF_CUDA_ERROR(cudaGetLastError());
  *num_saved_outputs = num_outputs;
  return Status::OK();
}

// Extracts a scalar of type T from a tensor, with correct type checking.
// This is necessary because several of the kernels here assume
// T == T_threshold.
template <typename T>
T GetScalar(const Tensor& tensor) {
  switch (tensor.dtype()) {
    case DT_FLOAT:
      return static_cast<T>(tensor.scalar<float>()());
    case DT_DOUBLE:
      return static_cast<T>(tensor.scalar<double>()());
    case DT_BFLOAT16:
      return static_cast<T>(tensor.scalar<Eigen::bfloat16>()());
    case DT_HALF:
      return static_cast<T>(tensor.scalar<Eigen::half>()());
    default:
      DCHECK(false) << "Unsupported type " << tensor.dtype();
      break;
  }
  return static_cast<T>(0);
}

Status CheckValidInputs(const Tensor& boxes, const Tensor& scores,
                        const Tensor& max_output_size,
                        const Tensor& iou_threshold) {
  if (!TensorShapeUtils::IsScalar(max_output_size.shape())) {
    return errors::InvalidArgument("max_output_size must be 0-D, got shape ",
                                   max_output_size.shape().DebugString(),
                                   " (Shape must be rank 0 but is ", "rank ",
                                   max_output_size.dims(), ")");
  }
  if (!TensorShapeUtils::IsScalar(iou_threshold.shape())) {
    return errors::InvalidArgument("iou_threshold must be 0-D, got shape ",
                                   iou_threshold.shape().DebugString(),
                                   " (Shape must be rank 0 but is rank ",
                                   iou_threshold.dims(), ")");
  }
  const float iou_threshold_val = GetScalar<float>(iou_threshold);
  if (iou_threshold_val < 0 || iou_threshold_val > 1) {
    return errors::InvalidArgument("iou_threshold must be in [0, 1]");
  }
  if (boxes.dims() != 2) {
    return errors::InvalidArgument(
        "boxes must be a rank 2 tensor! (Shape must "
        "be rank 2 but is rank ",
        boxes.dims(), ")");
  }
  int num_boxes = boxes.dim_size(0);
  if (boxes.dim_size(1) != 4) {
    return errors::InvalidArgument(
        "boxes must be Nx4 (Dimension must be 4 but"
        " is ",
        boxes.dim_size(1), ")");
  }
  if (scores.dims() != 1) {
    return errors::InvalidArgument(
        "scores must be a vector! (Shape must be "
        "rank 1 but is rank ",
        scores.dims(), ")");
  }
  if (scores.dim_size(0) != num_boxes) {
    return errors::InvalidArgument(
        "scores has incompatible shape "        // message must be exactly this
        "(Dimensions must be equal, but are ",  // otherwise tests fail!
        num_boxes, " and ", scores.dim_size(0), ")");
  }
  return Status::OK();
}
class NonMaxSuppressionV2GPUOp : public OpKernel {
 public:
  explicit NonMaxSuppressionV2GPUOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // boxes: [num_boxes, 4]
    const Tensor& boxes = context->input(0);
    // scores: [num_boxes]
    const Tensor& scores = context->input(1);
    // max_output_size: scalar
    const Tensor& max_output_size = context->input(2);
    // iou_threshold: scalar
    const Tensor& iou_threshold = context->input(3);
    auto valid =
        CheckValidInputs(boxes, scores, max_output_size, iou_threshold);
    if (!valid.ok()) {
      context->SetStatus(valid);
      return;
    }
    int num_boxes = boxes.dim_size(0);
    if (num_boxes == 0) {
      Tensor* output_indices = nullptr;
      OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({0}),
                                                       &output_indices));
      return;
    }
    const float iou_threshold_val = GetScalar<float>(iou_threshold);
    const int64_t output_size = max_output_size.scalar<int>()();

    OP_REQUIRES_OK(
        context,
        DoNMS(context, boxes, scores, output_size, iou_threshold_val,
              /*score_threshold is float lowest if score threshold is disabled*/
              std::numeric_limits<float>::lowest(),
              /*pad_to_max_output=*/false, &num_boxes));
  }
};

class NonMaxSuppressionV3GPUOp : public OpKernel {
 public:
  explicit NonMaxSuppressionV3GPUOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // boxes: [num_boxes, 4]
    const Tensor& boxes = context->input(0);
    // scores: [num_boxes]
    const Tensor& scores = context->input(1);
    // max_output_size: scalar
    const Tensor& max_output_size = context->input(2);
    // iou_threshold: scalar
    const Tensor& iou_threshold = context->input(3);
    auto valid =
        CheckValidInputs(boxes, scores, max_output_size, iou_threshold);
    if (!valid.ok()) {
      context->SetStatus(valid);
      return;
    }

    const Tensor& score_threshold = context->input(4);
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(score_threshold.shape()),
        errors::InvalidArgument("score_threshold must be 0-D, got shape ",
                                score_threshold.shape().DebugString()));
    const float score_threshold_val = GetScalar<float>(score_threshold);
    int num_boxes = boxes.dim_size(0);
    if (num_boxes == 0) {
      Tensor* output_indices = nullptr;
      OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({0}),
                                                       &output_indices));
      return;
    }
    const float iou_threshold_val = GetScalar<float>(iou_threshold);
    const int64_t output_size = max_output_size.scalar<int>()();
    OP_REQUIRES_OK(context, DoNMS(context, boxes, scores, output_size,
                                  iou_threshold_val, score_threshold_val,
                                  /*pad_to_max_output=*/false, &num_boxes));
  }
};

class NonMaxSuppressionV4GPUOp : public OpKernel {
 public:
  explicit NonMaxSuppressionV4GPUOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("pad_to_max_output_size",
                                             &pad_to_max_output_size_));
  }

  void Compute(OpKernelContext* context) override {
    // boxes: [num_boxes, 4]
    const Tensor& boxes = context->input(0);
    // scores: [num_boxes]
    const Tensor& scores = context->input(1);
    // max_output_size: scalar
    const Tensor& max_output_size = context->input(2);
    // iou_threshold: scalar
    const Tensor& iou_threshold = context->input(3);
    auto valid =
        CheckValidInputs(boxes, scores, max_output_size, iou_threshold);
    if (!valid.ok()) {
      context->SetStatus(valid);
      return;
    }

    const Tensor& score_threshold = context->input(4);
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(score_threshold.shape()),
        errors::InvalidArgument("score_threshold must be 0-D, got shape ",
                                score_threshold.shape().DebugString()));
    const float score_threshold_val = GetScalar<float>(score_threshold);

    Tensor* num_outputs_t = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(1, tensorflow::TensorShape({}),
                                            &num_outputs_t));
    auto device = context->eigen_gpu_device();
    int num_boxes = boxes.dim_size(0);
    if (num_boxes == 0) {
      Tensor* output_indices = nullptr;
      OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({}),
                                                       &output_indices));
      device.memcpy(num_outputs_t->flat<int>().data(), &num_boxes, sizeof(int));
      return;
    }

    const float iou_threshold_val = GetScalar<float>(iou_threshold);
    const int64_t output_size = max_output_size.scalar<int>()();
    int num_outputs = 0;
    OP_REQUIRES_OK(context, DoNMS(context, boxes, scores, output_size,
                                  iou_threshold_val, score_threshold_val,
                                  pad_to_max_output_size_, &num_outputs));
    device.memcpyHostToDevice(num_outputs_t->flat<int>().data(), &num_outputs,
                              sizeof(int));
    return;
  }

 private:
  bool pad_to_max_output_size_;
};

}  // namespace

Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes,
              const float iou_threshold, int* d_selected_indices, int* h_nkeep,
              OpKernelContext* context, const int max_boxes, bool flip_boxes) {
  // Making sure we respect the __align__(16)
  // we promised to the compiler.
  auto iptr = reinterpret_cast<std::uintptr_t>(d_sorted_boxes_float_ptr);
  if ((iptr & 15) != 0) {
    return errors::InvalidArgument("Boxes should be aligned to 16 Bytes.");
  }
  // Allocate the bitmask arrays on the host and on the device.
  Tensor h_num_selected, d_nms_mask;
  const int bit_mask_len =
      (num_boxes + kNmsBoxesPerThread - 1) / kNmsBoxesPerThread;

  int64 max_nms_mask_size = num_boxes * bit_mask_len;
  TF_RETURN_IF_ERROR(context->allocate_temp(
      DataType::DT_INT32, TensorShape({max_nms_mask_size}), &d_nms_mask));
  // Reset data-sensitive tensors.
  auto device = context->eigen_gpu_device();
  auto config = GetGpuLaunchConfig(d_nms_mask.NumElements(), device);
  TF_CHECK_OK(GpuLaunchKernel(SetZero<int>, config.block_count,
                              config.thread_per_block, 0, device.stream(),
                              config.virtual_thread_count,
                              d_nms_mask.flat<int32>().data()));

  // h_num_selected is a host pinned tensor.  The GPU kernel can write to it
  // directly, instead of writing to GPU memory and then copying down to
  // num_selected, saving us a small D2H memcpy.  We've observed that even small
  // D2H copies on the compute stream can have an outsized effect on latency.
  AllocatorAttributes pinned_alloc_attrs;
  pinned_alloc_attrs.set_on_host(true);
  pinned_alloc_attrs.set_gpu_compatible(true);
  TF_RETURN_IF_ERROR(context->allocate_temp(DataType::DT_INT32,
                                            TensorShape({1}), &h_num_selected,
                                            pinned_alloc_attrs));

  int* d_delete_mask = d_nms_mask.flat<int>().data();
  int* h_selected_count = h_num_selected.flat<int>().data();
  const Box* d_sorted_boxes =
      reinterpret_cast<const Box*>(d_sorted_boxes_float_ptr);
  dim3 block_dim, thread_block;
  int num_blocks = (num_boxes + kNmsBlockDim - 1) / kNmsBlockDim;
  num_blocks = std::max(std::min(num_blocks, kNmsBlockDimMax), 1);
  block_dim.x = num_blocks;
  block_dim.y = num_blocks;
  block_dim.z = 1;
  thread_block.x = kNmsBlockDim;
  thread_block.y = kNmsBlockDim;
  thread_block.z = 1;
  if (flip_boxes) {
    TF_CHECK_OK(GpuLaunchKernel(NMSKernel<true>, block_dim, thread_block, 0,
                                device.stream(), d_sorted_boxes, num_boxes,
                                iou_threshold, bit_mask_len, d_delete_mask));
  } else {
    TF_CHECK_OK(GpuLaunchKernel(NMSKernel<false>, block_dim, thread_block, 0,
                                device.stream(), d_sorted_boxes, num_boxes,
                                iou_threshold, bit_mask_len, d_delete_mask));
  }
  TF_RETURN_IF_CUDA_ERROR(cudaGetLastError());
  // Overlap the CPU compute with the D2H memcpy;
  // both take about the same time.

  config = GetGpuLaunchConfig(num_boxes, device);
  Tensor selected_boxes;
  TF_RETURN_IF_ERROR(context->allocate_temp(
      DataType::DT_INT8, TensorShape({num_boxes}), &selected_boxes));
  Tensor d_indices;
  TF_RETURN_IF_ERROR(context->allocate_temp(
      DataType::DT_INT32, TensorShape({num_boxes}), &d_indices));
  TF_CHECK_OK(GpuLaunchKernel(Iota<int>, config.block_count,
                              config.thread_per_block, 0, device.stream(),
                              config.virtual_thread_count, 0,
                              d_indices.flat<int>().data()));

  char* selected = (char*)(selected_boxes.flat<int8>().data());
  TF_CHECK_OK(GpuLaunchKernel(NMSReduce, 1, 1024, bit_mask_len * sizeof(int),
                              device.stream(), d_delete_mask, bit_mask_len,
                              num_boxes, max_boxes, selected));
  TF_RETURN_IF_CUDA_ERROR(cudaGetLastError());
  // Do cub::DeviceSelect::Flagged.
  size_t flagged_buffer_size = 0;
  gpuprim::DeviceSelect::Flagged(static_cast<void*>(nullptr),  // temp_storage
                                 flagged_buffer_size,
                                 static_cast<int*>(nullptr),   // input
                                 static_cast<char*>(nullptr),  // selection flag
                                 static_cast<int*>(nullptr),   // selected items
                                 static_cast<int*>(nullptr),   // num_selected
                                 num_boxes, device.stream());
  Tensor cub_scratch;
  TF_RETURN_IF_ERROR(context->allocate_temp(
      DataType::DT_INT8, TensorShape({(int64)flagged_buffer_size}),
      &cub_scratch));
  Tensor d_num_selected;
  TF_RETURN_IF_ERROR(context->allocate_temp(DataType::DT_INT32,
                                            TensorShape({1}), &d_num_selected));

  gpuprim::DeviceSelect::Flagged(
      (void*)cub_scratch.flat<int8>().data(),  // temp_storage
      flagged_buffer_size,
      d_indices.flat<int>().data(),  // input
      selected,                      // selection flag
      d_selected_indices,            // selected items
      h_selected_count, num_boxes, device.stream());
  gpuEvent_t copy_done;
  TF_RETURN_IF_CUDA_ERROR(
      gpuEventCreateWithFlags(&copy_done, gpuEventDisableTiming));
  TF_RETURN_IF_CUDA_ERROR(gpuEventRecord(copy_done, device.stream()));
  TF_RETURN_IF_CUDA_ERROR(gpuEventSynchronize(copy_done));
  gpuEventDestroy(copy_done);

  *h_nkeep = *h_selected_count;
  return Status::OK();
}

REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV2")
                            .TypeConstraint<float>("T")
                            .Device(DEVICE_GPU)
                            .HostMemory("iou_threshold")
                            .HostMemory("max_output_size"),
                        NonMaxSuppressionV2GPUOp);

REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV3")
                            .TypeConstraint<float>("T")
                            .Device(DEVICE_GPU)
                            .HostMemory("iou_threshold")
                            .HostMemory("max_output_size")
                            .HostMemory("score_threshold"),
                        NonMaxSuppressionV3GPUOp);

REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV4")
                            .TypeConstraint<float>("T")
                            .Device(DEVICE_GPU)
                            .HostMemory("iou_threshold")
                            .HostMemory("max_output_size")
                            .HostMemory("score_threshold"),
                        NonMaxSuppressionV4GPUOp);

}  // namespace tensorflow
#endif