1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/native/Cross.h>
3 #include <ATen/cuda/detail/KernelUtils.h>
4 #include <ATen/native/cuda/Loops.cuh>
5 #include <ATen/Dispatch.h>
6 #include <ATen/core/Tensor.h>
7
8 namespace at::native {
9
// Row-wise 3-D cross product: out_row = x1_row x x2_row.
//
// For each index in [0, numel), `offset_calculator.get(index)` yields element
// offsets into out / x1 / x2 (slots 0 / 1 / 2 respectively). From each row
// base, the three vector components sit 0, 1 * stride and 2 * stride elements
// away, using that operand's own stride (ostride / x1stride / x2stride).
//
// StrideType is `int` or `int64_t`; launch_cross_kernel picks `int` only when
// 2 * stride fits in an int for every operand.
//
// All six input components of a row are loaded before any output component
// of that row is stored.
//
// Launch: 1-D grid/block as configured by launch_cross_kernel; the index loop
// is CUDA_KERNEL_LOOP from ATen/cuda/detail/KernelUtils.h (NOTE(review):
// presumably a grid-stride loop -- confirm there). No shared memory is used.
template <typename T, typename OffsetCalc, typename StrideType>
__global__ void cross_kernel(
    int numel, T* out, const T* x1, const T* x2, OffsetCalc offset_calculator,
    StrideType ostride, StrideType x1stride, StrideType x2stride) {
  CUDA_KERNEL_LOOP(row, numel) {
    const auto base = offset_calculator.get(row);
    const T* lhs = x1 + base[1];
    const T* rhs = x2 + base[2];
    T* dst = out + base[0];

    // Gather both input vectors: a = x1 row, b = x2 row.
    const T a0 = lhs[0];
    const T a1 = lhs[x1stride];
    const T a2 = lhs[2 * x1stride];
    const T b0 = rhs[0];
    const T b1 = rhs[x2stride];
    const T b2 = rhs[2 * x2stride];

    // c = a x b, one component at a time.
    const T c0 = a1 * b2 - a2 * b1;
    const T c1 = a2 * b0 - a0 * b2;
    const T c2 = a0 * b1 - a1 * b0;

    dst[0] = c0;
    dst[ostride] = c1;
    dst[2 * ostride] = c2;
  }
}
35
// Launches cross_kernel over every element of `iter` on the current CUDA
// stream.
//
// `iter` operands are expected in the order (out, x1, x2), which is how
// cross_impl builds it. `ostride` / `x1stride` / `x2stride` are the element
// strides of those tensors along the cross-product dimension.
//
// Precondition: 0 < iter.numel() <= INT32_MAX (asserted in debug builds only).
// cross_impl calls this only after its numel() == 0 early return, and only on
// iterators that are 32-bit indexable.
//
// Launch configuration: 1-D grid of ceil(numel / num_threads()) blocks,
// num_threads() threads per block, no dynamic shared memory. The dtype is
// dispatched over all standard and complex types plus Half and BFloat16.
void launch_cross_kernel(const TensorIteratorBase& iter, int64_t ostride,
                         int64_t x1stride, int64_t x2stride) {
  const auto numel = iter.numel();
  auto offset_calc = make_element_offset_calculator<3>(iter);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(numel > 0 && numel <= std::numeric_limits<int32_t>::max());
  const auto threads_per_block = num_threads();
  const int64_t blocks = (numel + threads_per_block - 1) / threads_per_block;
  auto stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, iter.common_dtype(), "cross_cuda", [&] {
    auto* out_data = static_cast<scalar_t*>(iter.data_ptr(0));
    const auto* x1_data = static_cast<const scalar_t*>(iter.data_ptr(1));
    const auto* x2_data = static_cast<const scalar_t*>(iter.data_ptr(2));

    // The kernel forms offsets of up to 2 * stride. Narrow the strides to int
    // only when every such offset fits (NOTE(review): presumably for cheaper
    // 32-bit index arithmetic inside the kernel).
    constexpr int64_t int_max = std::numeric_limits<int>::max();
    const bool strides_fit_int =
        ostride * 2 <= int_max && x1stride * 2 <= int_max && x2stride * 2 <= int_max;

    if (strides_fit_int) {
      cross_kernel<<<blocks, threads_per_block, 0, stream>>>(
          numel, out_data, x1_data, x2_data, offset_calc,
          static_cast<int>(ostride),
          static_cast<int>(x1stride),
          static_cast<int>(x2stride));
    } else {
      cross_kernel<<<blocks, threads_per_block, 0, stream>>>(
          numel, out_data, x1_data, x2_data, offset_calc,
          ostride, x1stride, x2stride);
    }
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  });
}
63
// Implementation registered for cross_stub below: writes result = x1 x x2
// along dimension `dim`.
//
// The strides of all three tensors along `dim` are read first. A
// TensorIterator is then built over (result, x1, x2) with result's sizes as
// the static shape, `dim` squashed, and no output resizing. The kernel steps
// through the vector components itself using those strides (NOTE(review):
// relies on declare_static_shape's squash semantics making each iterator
// element one vector base -- confirm in TensorIterator.h).
//
// Empty iterators are a no-op. An iterator that is not 32-bit indexable is
// split with with_32bit_indexing(), and each piece is launched separately.
void cross_impl(const Tensor& result, const Tensor& x1, const Tensor& x2, int64_t dim) {
  const int64_t ostride = result.stride(dim);
  const int64_t x1stride = x1.stride(dim);
  const int64_t x2stride = x2.stride(dim);

  auto iter = TensorIteratorConfig()
      .add_output(result)
      .add_const_input(x1)
      .add_const_input(x2)
      .resize_outputs(false)
      .declare_static_shape(result.sizes(), /*squash_dims=*/dim)
      .build();

  if (iter.numel() == 0) {
    return;  // nothing to compute
  }

  const auto launch = [&](const TensorIteratorBase& piece) {
    launch_cross_kernel(piece, ostride, x1stride, x2stride);
  };

  if (iter.can_use_32bit_indexing()) {
    launch(iter);
    return;
  }
  for (auto&& sub_iter : iter.with_32bit_indexing()) {
    launch(sub_iter);
  }
}

// Bind cross_impl to cross_stub for this translation unit (NOTE(review):
// cross_stub is presumably declared in ATen/native/Cross.h).
REGISTER_DISPATCH(cross_stub, &cross_impl);
91
92 } // namespace at::native
93