#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/Cross.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/Dispatch.h>
#include <ATen/core/Tensor.h>

namespace at::native {

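// Computes out = x1 x x2 for each 3-element row selected by the offset
// calculator. Each thread of the grid-stride loop handles one cross
// product; the stride arguments give the element spacing along the
// cross dimension.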
template <typename T, typename OffsetCalc, typename StrideType>
__global__ void cross_kernel(
    int numel, T* out, const T* x1, const T* x2, OffsetCalc offset_calculator,
    StrideType ostride, StrideType x1stride, StrideType x2stride) {
  CUDA_KERNEL_LOOP(i, numel) {
    const auto offsets = offset_calculator.get(i);
    auto* out_row = out + offsets[0];
    const auto* x1_row = x1 + offsets[1];
    const auto* x2_row = x2 + offsets[2];

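    // Standard 3D cross product:
    //   (a1*b2 - a2*b1, a2*b0 - a0*b2, a0*b1 - a1*b0)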
    const T val0 = (x1_row[1 * x1stride] * x2_row[2 * x2stride] -
                    x1_row[2 * x1stride] * x2_row[1 * x2stride]);

    const T val1 = (x1_row[2 * x1stride] * x2_row[0 * x2stride] -
                    x1_row[0 * x1stride] * x2_row[2 * x2stride]);

    const T val2 = (x1_row[0 * x1stride] * x2_row[1 * x2stride] -
                    x1_row[1 * x1stride] * x2_row[0 * x2stride]);

    out_row[0 * ostride] = val0;
    out_row[1 * ostride] = val1;
    out_row[2 * ostride] = val2;
  }
}

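// Dispatches on dtype and launches cross_kernel. The iterator must be
// 32-bit indexable (the caller splits it if necessary), so numel fits
// in an int.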
void launch_cross_kernel(const TensorIteratorBase& iter, int64_t ostride,
                         int64_t x1stride, int64_t x2stride) {
  const auto N = iter.numel();
  auto offset_calculator = make_element_offset_calculator<3>(iter);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(N > 0 && N <= std::numeric_limits<int32_t>::max());
  int64_t grid = (N + num_threads() - 1) / num_threads();
  auto stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, iter.common_dtype(), "cross_cuda", [&] {
    auto out = static_cast<scalar_t*>(iter.data_ptr(0));
    auto x1 = static_cast<const scalar_t*>(iter.data_ptr(1));
    auto x2 = static_cast<const scalar_t*>(iter.data_ptr(2));
    constexpr int64_t int_max = std::numeric_limits<int>::max();
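    // The kernel indexes at most 2 * stride past each row pointer, hence
    // the "* 2" overflow checks below. When all strides fit in an int,
    // pass them as int (presumably to shrink the kernel's index
    // arithmetic); otherwise fall back to int64_t.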
    if (ostride * 2 > int_max || x1stride * 2 > int_max || x2stride * 2 > int_max) {
      cross_kernel<<<grid, num_threads(), 0, stream>>>(
          N, out, x1, x2, offset_calculator, ostride, x1stride, x2stride);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    } else {
      cross_kernel<<<grid, num_threads(), 0, stream>>>(
          N, out, x1, x2, offset_calculator,
          static_cast<int>(ostride),
          static_cast<int>(x1stride),
          static_cast<int>(x2stride));
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }
  });
}

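// CUDA backend for the cross_stub dispatch below; roughly what runs when
// torch.linalg.cross is called on CUDA tensors. `dim` is the size-3
// dimension the cross product is taken over.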
void cross_impl(const Tensor& result, const Tensor& x1, const Tensor& x2, int64_t dim) {
  const int64_t ostride = result.stride(dim);
  const int64_t x1stride = x1.stride(dim);
  const int64_t x2stride = x2.stride(dim);

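  // Squash the cross dimension out of the iterated shape so each iterator
  // element corresponds to one 3-vector; the kernel walks the three
  // components itself via the strides captured above.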
  auto iter = TensorIteratorConfig()
      .add_output(result)
      .add_const_input(x1)
      .add_const_input(x2)
      .resize_outputs(false)
      .declare_static_shape(result.sizes(), /*squash_dims=*/dim)
      .build();

  if (iter.numel() == 0) {
    return;
  }

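  // launch_cross_kernel requires 32-bit indexing; split the iterator into
  // 32-bit-indexable pieces when the tensors are too large for that.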
  if (iter.can_use_32bit_indexing()) {
    launch_cross_kernel(iter, ostride, x1stride, x2stride);
  } else {
    for (auto&& sub_iter: iter.with_32bit_indexing()) {
      launch_cross_kernel(sub_iter, ostride, x1stride, x2stride);
    }
  }
}

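// Make cross_impl the implementation behind cross_stub for this build.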
REGISTER_DISPATCH(cross_stub, &cross_impl);

} // namespace at::native