// aten/src/ATen/native/sparse/SparseBinaryOpIntersectionKernel.cpp
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/sparse/SparseStubs.h>
#include <ATen/native/sparse/SparseBinaryOpIntersectionCommon.h>
#include <ATen/native/cpu/Loops.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/AccumulateType.h>

namespace at::native {

namespace {

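// CPU launcher passed to _sparse_binary_op_intersection_kernel_out;
// it simply forwards the element-wise functor to cpu_kernel.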
template <typename func_t>
struct CPUKernelLauncher {
  static void launch(TensorIteratorBase& iter, const func_t& f) {
    cpu_kernel(iter, f);
  }
};

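// Element-wise multiplication of a matched pair of lhs/rhs values.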
struct MulOp {
  template <typename scalar_t>
  static scalar_t apply(scalar_t a, scalar_t b) {
    return a * b;
  }
};

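// Specialization for bool: multiplication of two bools is logical AND.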
template <>
bool MulOp::apply(bool a, bool b) {
  return a && b;
}

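// Projection functors: for a matched pair, keep only the rhs value
// (RhsProjOp) or only the lhs value (LhsProjOp). These back the
// sparse-mask style kernels below.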
struct RhsProjOp {
  template <typename scalar_t>
  static scalar_t apply(scalar_t a, scalar_t b) {
    return b;
  }
};

struct LhsProjOp {
  template <typename scalar_t>
  static scalar_t apply(scalar_t a, scalar_t b) {
    return a;
  }
};

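// Computes the values of the intersection result on CPU. For every selected
// lhs non-zero, the lhs value is combined via binary_op_t with the value(s)
// of the matching rhs non-zeros (located through argsort), either
// accumulating over all matches or using only the first one, depending on
// accumulate_matches.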
template <typename binary_op_t>
struct CPUValueSelectionIntersectionKernel {
  static Tensor apply(
      const Tensor& lhs_values,
      const Tensor& lhs_select_idx,
      const Tensor& rhs_values,
      const Tensor& rhs_select_idx,
      const Tensor& intersection_counts,
      const Tensor& argsort,
      const bool accumulate_matches) {
    auto iter = make_value_selection_intersection_iter(
        lhs_values,
        lhs_select_idx,
        rhs_values,
        rhs_select_idx,
        intersection_counts);
    auto res_values = iter.tensor(0);

    auto lhs_nnz_stride = lhs_values.stride(0);
    auto rhs_nnz_stride = rhs_values.stride(0);

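    // Dispatch on the result dtype; the loop below walks the raw byte
    // pointers and strides supplied by the TensorIterator.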
    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
        ScalarType::Bool, ScalarType::Half, ScalarType::BFloat16, at::ScalarType::ComplexHalf,
        res_values.scalar_type(),
        "binary_op_intersection_cpu", [&] {
            // COO indices are only 64-bit for now.
            using index_t = int64_t;
            auto loop = [&](char** data, const int64_t* strides, int64_t n) {
              auto* ptr_res_values_bytes = data[0];
              const auto* ptr_lhs_values_bytes = data[1];
              const auto* ptr_lhs_select_idx_bytes = data[2];
              const auto* ptr_rhs_values_bytes = data[3];
              const auto* ptr_rhs_select_idx_bytes = data[4];
              const auto* ptr_intersection_counts_bytes = data[5];
              const auto* ptr_argsort = argsort.const_data_ptr<index_t>();

              for (int64_t i = 0; i < n; ++i) {
                // Extract data
                auto* ptr_res_values = reinterpret_cast<scalar_t*>(ptr_res_values_bytes);
                const auto* ptr_lhs_values = reinterpret_cast<const scalar_t*>(ptr_lhs_values_bytes);
                const auto lhs_nnz_idx = *reinterpret_cast<const index_t*>(ptr_lhs_select_idx_bytes);
                const auto* ptr_rhs_values = reinterpret_cast<const scalar_t*>(ptr_rhs_values_bytes);
                const auto rhs_nnz_idx = *reinterpret_cast<const index_t*>(ptr_rhs_select_idx_bytes);
                const auto count = *reinterpret_cast<const int64_t*>(ptr_intersection_counts_bytes);

                const auto* ptr_lhs_begin = ptr_lhs_values + lhs_nnz_idx * lhs_nnz_stride;
                const auto* ptr_rhs_sorted_nnz_idx = ptr_argsort + rhs_nnz_idx;

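                // Combine the lhs value with each matching rhs value. Sums are
                // accumulated in acc_type and cast back to scalar_t at the end;
                // when accumulate_matches is false only the first match is used.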
                using accscalar_t = at::acc_type<scalar_t, /*is_gpu=*/false>;
                accscalar_t res_values = 0;
                accscalar_t lhs_values = static_cast<accscalar_t>(*ptr_lhs_begin);
                accscalar_t rhs_values;
                index_t rhs_sorted_nnz_idx;
                const auto match_count = accumulate_matches ? count : std::min<int64_t>(count, 1);
                for (int64_t c = 0; c < match_count; ++c) {
                  rhs_sorted_nnz_idx = *ptr_rhs_sorted_nnz_idx++;
                  rhs_values = static_cast<accscalar_t>(*(ptr_rhs_values + rhs_sorted_nnz_idx * rhs_nnz_stride));
                  res_values += binary_op_t::apply(lhs_values, rhs_values);
                }
                *ptr_res_values = static_cast<scalar_t>(res_values);

                // Advance
                ptr_res_values_bytes += strides[0];
                ptr_lhs_values_bytes += strides[1];
                ptr_lhs_select_idx_bytes += strides[2];
                ptr_rhs_values_bytes += strides[3];
                ptr_rhs_select_idx_bytes += strides[4];
                ptr_intersection_counts_bytes += strides[5];
              }
            };
            iter.for_each(loop, at::internal::GRAIN_SIZE);
        });

    return res_values;
  }
};

using OptTensor = std::optional<Tensor>;

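// CPU stub: element-wise multiplication of two sparse COO tensors over the
// intersection of their indices.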
void mul_sparse_sparse_out_cpu_kernel(
    Tensor& result,
    const Tensor& x,
    const Tensor& y) {
  using CPUValueSelectionMulKernel = CPUValueSelectionIntersectionKernel<MulOp>;
  _sparse_binary_op_intersection_kernel_out<CPUKernelLauncher, CPUValueSelectionMulKernel>(
      result, x, y
  );
}

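// CPU stub for sparse mask intersection: projects y's values onto the indices
// shared with x (RhsProjOp). The optional x_hash_opt is forwarded to the
// common intersection kernel.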
void sparse_mask_intersection_out_cpu_kernel(
    Tensor& result,
    const Tensor& x,
    const Tensor& y,
    const OptTensor& x_hash_opt = std::nullopt) {
  using CPUValueRhsProjKernel = CPUValueSelectionIntersectionKernel<RhsProjOp>;
  _sparse_binary_op_intersection_kernel_out<CPUKernelLauncher, CPUValueRhsProjKernel>(
      result, x, y, x_hash_opt
  );
}

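// CPU stub for sparse mask projection: keeps x's values at the shared indices
// (LhsProjOp), optionally accumulating over repeated matches.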
void sparse_mask_projection_out_cpu_kernel(
    Tensor& result,
    const Tensor& x,
    const Tensor& y,
    const OptTensor& x_hash_opt,
    bool accumulate_matches) {
  using CPUValueLhsProjKernel = CPUValueSelectionIntersectionKernel<LhsProjOp>;
  _sparse_binary_op_intersection_kernel_out<CPUKernelLauncher, CPUValueLhsProjKernel>(
      result, x, y, x_hash_opt, std::nullopt, accumulate_matches
  );
}

} // anonymous namespace

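// Register the CPU kernels with the dispatch stubs; the same scalar
// implementation is reused for every CPU capability.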
REGISTER_ARCH_DISPATCH(mul_sparse_sparse_out_stub, DEFAULT, &mul_sparse_sparse_out_cpu_kernel);
REGISTER_AVX512_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel);
REGISTER_AVX2_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel);
REGISTER_VSX_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel);
REGISTER_ZVECTOR_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel);

REGISTER_ARCH_DISPATCH(sparse_mask_intersection_out_stub, DEFAULT, &sparse_mask_intersection_out_cpu_kernel);
REGISTER_AVX512_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel);
REGISTER_AVX2_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel);
REGISTER_VSX_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel);
REGISTER_ZVECTOR_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel);

REGISTER_ARCH_DISPATCH(sparse_mask_projection_out_stub, DEFAULT, &sparse_mask_projection_out_cpu_kernel);
REGISTER_AVX512_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel);
REGISTER_AVX2_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel);
REGISTER_VSX_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel);
REGISTER_ZVECTOR_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel);
} // namespace at::native