xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/ReduceAMinMaxKernel.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_NO_OPERATORS
2 #include <ATen/Dispatch.h>
3 #include <ATen/NumericUtils.h>
4 #include <ATen/native/DispatchStub.h>
5 #include <ATen/native/ReduceAllOps.h>
6 #include <ATen/native/ReduceOps.h>
7 #include <ATen/native/SharedReduceOps.h>
8 #include <ATen/native/TensorCompare.h>
9 #include <ATen/native/TensorIterator.h>
10 #include <ATen/native/cuda/ReduceOps.h>
11 #include <ATen/cuda/NumericLimits.cuh>
12 #include <ATen/native/cuda/Reduce.cuh>
13 
14 #include <ATen/Dispatch.h>
15 #include <ATen/NumericUtils.h>
16 #include <ATen/cuda/NumericLimits.cuh>
17 
namespace at::native {

// Shared implementation behind both aminmax launchers in this file.
//
// Issues one `gpu_reduce_kernel` call over `iter` using `MinMaxOps` as the
// reduction functor, with `scalar_t` as both the input and the output element
// type and `int32_t` as the functor's third template argument.
//
// The reduction is seeded with the pair (upper_bound, lower_bound) of
// `scalar_t`.
// NOTE(review): presumably the first slot is the running minimum and the
// second the running maximum, so that any real element replaces the seed --
// confirm against MinMaxOps in SharedReduceOps.h.
//
// Template parameters:
//   scalar_t - element type selected by the caller's dtype dispatch.
// Parameters:
//   iter - reduction iterator, forwarded to `gpu_reduce_kernel` unchanged.
template <typename scalar_t>
void _min_max_values_kernel_cuda_impl(TensorIterator& iter) {
  gpu_reduce_kernel<scalar_t, scalar_t>(
      iter,
      MinMaxOps<scalar_t, scalar_t, int32_t>{},
      thrust::pair<scalar_t, scalar_t>(
          at::numeric_limits<scalar_t>::upper_bound(),
          at::numeric_limits<scalar_t>::lower_bound()));
}

// Dispatches on `iter.input_dtype()` (the AT_DISPATCH_ALL_TYPES set plus
// BFloat16, Half and Bool) and runs the shared min/max reduction for the
// selected element type. Dispatch diagnostics use the name
// "aminmax_all_cuda".
void aminmax_allreduce_launch_kernel(TensorIterator& iter) {
  AT_DISPATCH_ALL_TYPES_AND3(
      kBFloat16, kHalf, kBool, iter.input_dtype(), "aminmax_all_cuda", [&]() {
        _min_max_values_kernel_cuda_impl<scalar_t>(iter);
      });
}

// Dispatches on `iter.input_dtype()` over the same dtype set as
// `aminmax_allreduce_launch_kernel` and runs the same shared reduction;
// only the dispatch name ("aminmax_cuda") differs.
//
// This previously repeated the `gpu_reduce_kernel` call inline; it now goes
// through the shared helper so the functor and the seed values are defined in
// exactly one place.
void aminmax_launch_kernel(TensorIterator& iter) {
  AT_DISPATCH_ALL_TYPES_AND3(
      kBFloat16, kHalf, kBool, iter.input_dtype(), "aminmax_cuda", [&]() {
        _min_max_values_kernel_cuda_impl<scalar_t>(iter);
      });
}

} // namespace at::native
50