// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/SparseBinaryOpIntersectionKernel.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/sparse/SparseStubs.h>
#include <ATen/native/sparse/SparseBinaryOpIntersectionCommon.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/KernelUtils.cuh>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/AccumulateType.h>

namespace at::native {

namespace {

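// Thin wrapper used as the KernelLauncher argument of the common sparse
// intersection code: it forwards a TensorIterator-based functor to gpu_kernel
// for elementwise execution on CUDA.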
template <typename func_t>
struct CUDAKernelLauncher {
  static void launch(TensorIteratorBase& iter, const func_t& f) {
    gpu_kernel(iter, f);
  }
};

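// Pointwise functors applied to each matched (lhs, rhs) value pair.
// MulOp multiplies the values; the bool specialization below uses logical AND
// so that boolean "multiplication" keeps bool semantics.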
struct MulOp {
  template <typename scalar_t>
  static FUNCAPI INLINE scalar_t apply(scalar_t a, scalar_t b) {
    return a * b;
  }
};

template <>
FUNCAPI INLINE bool MulOp::apply(bool a, bool b) {
  return a && b;
}

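// Projection functors: RhsProjOp keeps only the right-hand value, LhsProjOp
// keeps only the left-hand value. They back the sparse_mask intersection and
// projection kernels defined below.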
struct RhsProjOp {
  template <typename scalar_t>
  static FUNCAPI scalar_t apply(scalar_t a, scalar_t b) {
    return b;
  }
};

struct LhsProjOp {
  template <typename scalar_t>
  static FUNCAPI scalar_t apply(scalar_t a, scalar_t b) {
    return a;
  }
};

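// Each block handles a tile of nt * vt elements; within the tile each thread
// processes up to vt elements strided by nt, calling loop(idx) for every
// in-bounds index.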
template <int nt, int vt, typename loop_t>
C10_LAUNCH_BOUNDS_2(nt, vt)
__global__ void apply_kernel(int n, loop_t loop) {
  constexpr int nv = nt * vt;
  int idx = nv * blockIdx.x + threadIdx.x;

  #pragma unroll
  for (int i = 0; i < vt; ++i) {
    if (idx < n) {
      loop(idx);
      idx += nt;
    }
  }
}

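// Host-side launcher for apply_kernel: asserts that n fits 32-bit indexing,
// derives the grid size from the nt/vt tile, and launches on the current
// CUDA stream.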
template <int nt, int vt, typename loop_t>
void launch_kernel(int64_t n, const loop_t& loop) {
  TORCH_INTERNAL_ASSERT(0 <= n && n <= std::numeric_limits<int32_t>::max());
  if (!n) {
    return;
  }

  const dim3 block(nt);
  const dim3 grid((n + block.x * vt - 1) / (block.x * vt));
  const auto stream = at::cuda::getCurrentCUDAStream();
  apply_kernel<nt, vt, loop_t><<<grid, block, 0, stream>>>(n, loop);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

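// Core intersection kernel: for each output element it reads the lhs nnz value
// selected by lhs_select_idx, walks the matching rhs nnz positions through the
// argsort permutation, combines each pair with binary_op_t, and either sums
// over all matches (accumulate_matches) or uses at most one. Iterators that
// need 64-bit indexing are first split into 32-bit-indexable sub-iterators.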
template <typename binary_op_t, typename scalar_t, typename index_t>
void binary_op_intersection_kernel(
    TensorIterator& iter,
    int64_t lhs_nnz_stride,
    int64_t rhs_nnz_stride,
    const Tensor& argsort,
    const bool accumulate_matches) {
  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      binary_op_intersection_kernel<binary_op_t, scalar_t, index_t>(
          sub_iter, lhs_nnz_stride, rhs_nnz_stride, argsort, accumulate_matches);
    }
    return;
  }

  auto* RESTRICT ptr_res_values_bytes = reinterpret_cast<char*>(iter.data_ptr(0));
  const auto* RESTRICT ptr_lhs_values_bytes = reinterpret_cast<char*>(iter.data_ptr(1));
  const auto* RESTRICT ptr_lhs_select_idx_bytes = reinterpret_cast<char*>(iter.data_ptr(2));
  const auto* RESTRICT ptr_rhs_values_bytes = reinterpret_cast<char*>(iter.data_ptr(3));
  const auto* RESTRICT ptr_rhs_select_idx_bytes = reinterpret_cast<char*>(iter.data_ptr(4));
  const auto* RESTRICT ptr_intersection_counts_bytes = reinterpret_cast<char*>(iter.data_ptr(5));
  const auto* RESTRICT ptr_argsort = argsort.const_data_ptr<index_t>();

  auto offset_calc = make_offset_calculator<6>(iter);
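  // Per-element device lambda: offsets into the 6 operands come from the
  // offset calculator; the accumulation is carried out in acc_type to reduce
  // rounding error for low-precision value types.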
  auto loop = [=] FUNCAPI (int i) {
    auto offsets = offset_calc.get(i);

    auto* RESTRICT ptr_res_values = reinterpret_cast<scalar_t*>(ptr_res_values_bytes + offsets[0]);
    const auto* RESTRICT ptr_lhs_values = reinterpret_cast<const scalar_t*>(ptr_lhs_values_bytes + offsets[1]);
    const auto lhs_nnz_idx = *reinterpret_cast<const index_t*>(ptr_lhs_select_idx_bytes + offsets[2]);
    const auto* RESTRICT ptr_rhs_values = reinterpret_cast<const scalar_t*>(ptr_rhs_values_bytes + offsets[3]);
    const auto rhs_nnz_idx = *reinterpret_cast<const index_t*>(ptr_rhs_select_idx_bytes + offsets[4]);
    const auto count = *reinterpret_cast<const int64_t*>(ptr_intersection_counts_bytes + offsets[5]);

    const auto* RESTRICT ptr_lhs_begin = ptr_lhs_values + lhs_nnz_idx * lhs_nnz_stride;
    const auto* RESTRICT ptr_rhs_sorted_nnz_idx = ptr_argsort + rhs_nnz_idx;

    using accscalar_t = at::acc_type<scalar_t, /*is_gpu=*/true>;
    accscalar_t res_values = 0;
    accscalar_t lhs_values = static_cast<accscalar_t>(*ptr_lhs_begin);
    accscalar_t rhs_values;
    index_t rhs_sorted_nnz_idx;
    const auto match_count = accumulate_matches ? count : std::min<int64_t>(count, 1);
    for (int64_t c = 0; c < match_count; ++c) {
      rhs_sorted_nnz_idx = *ptr_rhs_sorted_nnz_idx++;
      rhs_values = static_cast<accscalar_t>(*(ptr_rhs_values + rhs_sorted_nnz_idx * rhs_nnz_stride));
      res_values += binary_op_t::apply(lhs_values, rhs_values);
    }
    *ptr_res_values = static_cast<scalar_t>(res_values);
  };

  launch_kernel<num_threads(), thread_work_size()>(iter.numel(), loop);
}

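// Dispatches binary_op_intersection_kernel over the value dtype. The value
// functor is fixed at compile time via binary_op_t; the iterator is built by
// make_value_selection_intersection_iter from the selected nnz values and the
// per-element intersection counts.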
template <typename binary_op_t>
struct CUDAValueSelectionIntersectionKernel {
  static Tensor apply(
      const Tensor& lhs_values,
      const Tensor& lhs_select_idx,
      const Tensor& rhs_values,
      const Tensor& rhs_select_idx,
      const Tensor& intersection_counts,
      const Tensor& argsort,
      const bool accumulate_matches) {
    auto iter = make_value_selection_intersection_iter(
        lhs_values,
        lhs_select_idx,
        rhs_values,
        rhs_select_idx,
        intersection_counts);
    auto res_values = iter.tensor(0);
    // If res_values is empty, we can return it right away;
    // otherwise we run into floating point issues with OffsetCalculator.
    if (!res_values.numel()) {
      return res_values;
    }

    const auto lhs_nnz_stride = lhs_values.stride(0);
    const auto rhs_nnz_stride = rhs_values.stride(0);

    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
        ScalarType::Bool, ScalarType::Half, ScalarType::BFloat16, at::ScalarType::ComplexHalf, res_values.scalar_type(),
        "binary_op_intersection_cuda", [&] {
          // COO indices are only 64-bit for now.
          using index_t = int64_t;
          binary_op_intersection_kernel<binary_op_t, scalar_t, index_t>(
              iter, lhs_nnz_stride, rhs_nnz_stride, argsort, accumulate_matches);
        });

    return res_values;
  }
};

using OptTensor = std::optional<Tensor>;

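// Stub implementation for sparse * sparse multiplication on CUDA: values at
// coinciding indices are combined with MulOp.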
void mul_sparse_sparse_out_cuda_kernel(
    Tensor& result,
    const Tensor& x,
    const Tensor& y) {
  using CUDAValueSelectionMulKernel = CUDAValueSelectionIntersectionKernel<MulOp>;
  _sparse_binary_op_intersection_kernel_out<CUDAKernelLauncher, CUDAValueSelectionMulKernel>(
      result, x, y
  );
}

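// Stub implementation for the sparse_mask intersection on CUDA: the result
// takes y's values at indices that also appear in x (RhsProjOp). The optional
// x_hash_opt is forwarded to the common intersection code.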
void sparse_mask_intersection_out_cuda_kernel(
    Tensor& result,
    const Tensor& x,
    const Tensor& y,
    const OptTensor& x_hash_opt = std::nullopt) {
  using CUDAValueRhsProjKernel = CUDAValueSelectionIntersectionKernel<RhsProjOp>;
  _sparse_binary_op_intersection_kernel_out<CUDAKernelLauncher, CUDAValueRhsProjKernel>(
      result, x, y, x_hash_opt
  );
}

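// Stub implementation for the sparse mask projection on CUDA: the result keeps
// x's values at indices that also appear in y (LhsProjOp), optionally
// accumulating over repeated matches.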
void sparse_mask_projection_out_cuda_kernel(
    Tensor& result,
    const Tensor& x,
    const Tensor& y,
    const OptTensor& x_hash_opt,
    bool accumulate_matches) {
  using CUDAValueLhsProjKernel = CUDAValueSelectionIntersectionKernel<LhsProjOp>;
  _sparse_binary_op_intersection_kernel_out<CUDAKernelLauncher, CUDAValueLhsProjKernel>(
      result, x, y, x_hash_opt, std::nullopt, accumulate_matches
  );
}

} // anonymous namespace

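// Register the CUDA implementations with the corresponding sparse dispatch stubs.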
REGISTER_CUDA_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cuda_kernel);
REGISTER_CUDA_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cuda_kernel);
REGISTER_CUDA_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cuda_kernel);

} // namespace at::native