xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/ReduceArgMinKernel.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
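// CUDA kernel for argmin: computes, for each reduced slice, the int64 index
// of its smallest element.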
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/Dispatch.h>
#include <ATen/NumericUtils.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/ReduceAllOps.h>
#include <ATen/native/ReduceOps.h>
#include <ATen/native/SharedReduceOps.h>
#include <ATen/native/TensorCompare.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/ReduceOps.h>
#include <ATen/cuda/NumericLimits.cuh>
#include <ATen/native/cuda/Reduce.cuh>

// thrust::pair is used directly below; include it explicitly rather than
// relying on a transitive include.
#include <thrust/pair.h>

namespace at::native {

template <typename scalar_t, typename acc_t = scalar_t>
void argmin_kernel_cuda_impl(TensorIterator& iter) {
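  // Reduce (value, index) pairs with ArgMinOps: the pair with the smaller
  // value wins, NaN compares as smallest (so a slice containing NaN yields
  // the NaN's index), and ties resolve to the lower index. The accumulator
  // starts at acc_t's upper bound paired with index 0, so the first element
  // seen always replaces it.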
  gpu_reduce_kernel<scalar_t, int64_t>(
      iter,
      ArgMinOps<acc_t>{},
      thrust::pair<acc_t, int64_t>(
          at::numeric_limits<acc_t>::upper_bound(), 0));
}

void argmin_kernel_cuda(TensorIterator& iter) {
  // For float16 & bfloat16, instead of implementing is_nan and warp_shfl_down,
  // we can convert float16 & bfloat16 to float and do all the operations in
  // float.
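  // iter.dtype(1) is the input tensor's dtype; operand 0 is the int64 index
  // output.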
  if (iter.dtype(1) == kHalf) {
    argmin_kernel_cuda_impl<at::Half, float>(iter);
  } else if (iter.dtype(1) == kBFloat16) {
    argmin_kernel_cuda_impl<at::BFloat16, float>(iter);
  } else {
    AT_DISPATCH_ALL_TYPES(iter.dtype(1), "argmin_cuda", [&]() {
      argmin_kernel_cuda_impl<scalar_t>(iter);
    });
  }
}
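
// Register this kernel as the CUDA implementation behind argmin_stub, the
// DispatchStub that the device-generic argmin reduction dispatches through.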
REGISTER_DISPATCH(argmin_stub, &argmin_kernel_cuda);

} // namespace at::native