xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cpu/SampledAddmmKernel.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/ExpandUtils.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/cpu/SampledAddmmKernel.h>
#include <ATen/native/cpu/utils.h>
#include <c10/util/irange.h>

namespace at::native {

namespace {

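// For each batch b and each stored element e of the sparse CSR `result`
// located at (row m, col n), this kernel updates
//   values[b][e] = alpha * dot(mat1[b][m, :], mat2[b][n, :]) + beta * values[b][e],
// i.e. the matrix product is only evaluated at the sparsity pattern of `result`.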
template <typename scalar_t, typename index_t>
void sampled_addmm_sparse_csr_kernel_impl(
    const Tensor& mat1,
    const Tensor& mat2,
    const Scalar& beta,
    const Scalar& alpha,
    const Tensor& result) {

  int64_t nnz = result._nnz();

  auto beta_ = beta.to<scalar_t>();
  auto alpha_ = alpha.to<scalar_t>();

  const scalar_t* mat1_data = mat1.const_data_ptr<scalar_t>();
  const scalar_t* mat2_data = mat2.const_data_ptr<scalar_t>();

  // mat1: {B, M, K}
  // mat2: {B, N, K}
  // crow: {B, M + 1}
  // col, values: {B, nnz}
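  // Note: mat2 is laid out as {B, N, K} (not {B, K, N}), so the m-th row of mat1
  // and the n-th row of mat2 are both contiguous slices of length K, which the
  // vectorized dot product below relies on.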
  int64_t M = mat1.size(-2);
  int64_t K = mat1.size(-1);
  int64_t N = mat2.size(-2);
  int64_t B = mat1.numel() / M / K;

  auto values = result.values().reshape({-1, nnz});
  auto crow = result.crow_indices().reshape({-1, M + 1});
  auto col = result.col_indices().reshape({-1, nnz});

  auto values_acc = values.accessor<scalar_t, 2>();
  auto crow_acc = crow.accessor<const index_t, 2>();
  auto col_acc = col.accessor<const index_t, 2>();

  // Usually, collapsing B and M would be the better parallelization option,
  // but in the most common case (mat1 and mat2 are 2D tensors) B == 1,
  // so instead balance the partition of M by using parallel_sparse_csr.
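  // parallel_sparse_csr splits the row range [0, M) into chunks that carry a
  // roughly equal number of nonzeros (derived from crow_slice), rather than an
  // equal number of rows, so threads stay balanced for skewed sparsity patterns.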
  using Vec = vec::Vectorized<scalar_t>;
  for (const auto b : c10::irange(B)) {
    auto crow_slice = crow_acc[b];
    auto col_slice = col_acc[b];
    auto values_slice = values_acc[b];
    const scalar_t* mat1_ptr = mat1_data + b * M * K;
    const scalar_t* mat2_ptr = mat2_data + b * N * K;

    utils::parallel_sparse_csr(crow_slice, M, nnz, [&](int64_t begin, int64_t end) {
      for (const auto m : c10::irange(begin, end)) {
        int64_t row_start = crow_slice[m];
        int64_t row_end = crow_slice[m + 1];
        for (const auto e : c10::irange(row_start, row_end)) {
          int64_t n = col_slice[e];
          scalar_t val = values_slice[e];
          // vectorized dot product of mat1 row m and mat2 row n over the K dimension
          scalar_t dot = vec::map2_reduce_all<scalar_t>(
              [](Vec x, Vec y) { return x * y; },
              [](Vec x, Vec y) { return x + y; },
              mat1_ptr + m * K,
              mat2_ptr + n * K,
              K);
          val = alpha_ * dot + beta_ * val;
          values_slice[e] = val;
        }
      }
    });
  }
}

void sampled_addmm_sparse_csr_kernel(
    const Tensor& mat1,
    const Tensor& mat2,
    const Scalar& beta,
    const Scalar& alpha,
    const Tensor& result) {
  const auto index_type = result.crow_indices().scalar_type();
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(mat1.scalar_type(), "sampled_addmm_sparse_csr_kernel", [&]() {
    AT_DISPATCH_INDEX_TYPES(index_type, "sampled_addmm_sparse_csr_index", [&]() {
      sampled_addmm_sparse_csr_kernel_impl<scalar_t, index_t>(mat1, mat2, beta, alpha, result);
    });
  });
}

} // anonymous namespace

REGISTER_DISPATCH(sampled_addmm_sparse_csr_stub, &sampled_addmm_sparse_csr_kernel);
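
// Example (a minimal sketch, assuming the public ATen entry point
// at::sparse_sampled_addmm, which is expected to route CPU CSR inputs to the
// stub registered above; the names and shapes below are illustrative only):
//
//   at::Tensor mat1 = at::randn({64, 32});                      // dense {M, K}
//   at::Tensor mat2 = at::randn({32, 48});                      // dense {K, N}
//   at::Tensor pattern = at::randn({64, 48}).to_sparse_csr();   // sparse CSR {M, N}
//   // out keeps the sparsity pattern of `pattern`; each stored value becomes
//   // alpha * dot(mat1[m, :], mat2[:, n]) + beta * pattern.values()[e]
//   at::Tensor out = at::sparse_sampled_addmm(pattern, mat1, mat2, /*beta=*/1.0, /*alpha=*/1.0);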

} // at::native