#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/ExpandUtils.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/cpu/SampledAddmmKernel.h>
#include <ATen/native/cpu/utils.h>
#include <c10/util/irange.h>

namespace at::native {

namespace {

template <typename scalar_t, typename index_t>
void sampled_addmm_sparse_csr_kernel_impl(
    const Tensor& mat1,
    const Tensor& mat2,
    const Scalar& beta,
    const Scalar& alpha,
    const Tensor& result) {

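  // Sampled addmm on a sparse CSR result: for every stored element (m, n) of
  // `result`, update values[e] = alpha * dot(mat1[b][m, :], mat2[b][n, :]) + beta * values[e].
  // Only the existing sparsity pattern is visited; no new non-zeros are created.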
  int64_t nnz = result._nnz();

  auto beta_ = beta.to<scalar_t>();
  auto alpha_ = alpha.to<scalar_t>();

  const scalar_t* mat1_data = mat1.const_data_ptr<scalar_t>();
  const scalar_t* mat2_data = mat2.const_data_ptr<scalar_t>();

  // mat1: {B, M, K}
  // mat2: {B, N, K}
  // crow: {B, M + 1}
  // col, values: {B, nnz}
  int64_t M = mat1.size(-2);
  int64_t K = mat1.size(-1);
  int64_t N = mat2.size(-2);
  int64_t B = mat1.numel() / M / K;

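  // Collapse any leading batch dimensions so each component can be indexed as [b][...].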
  auto values = result.values().reshape({-1, nnz});
  auto crow = result.crow_indices().reshape({-1, M + 1});
  auto col = result.col_indices().reshape({-1, nnz});

  auto values_acc = values.accessor<scalar_t, 2>();
  auto crow_acc = crow.accessor<const index_t, 2>();
  auto col_acc = col.accessor<const index_t, 2>();

  // Usually, collapsing B and M is the better option,
  // but for the most commonly used case (mat1 and mat2 are 2d tensors) B = 1,
  // so balance the partition of M by using parallel_sparse_csr.
  using Vec = vec::Vectorized<scalar_t>;
  for (const auto b : c10::irange(B)) {
    auto crow_slice = crow_acc[b];
    auto col_slice = col_acc[b];
    auto values_slice = values_acc[b];
    const scalar_t* mat1_ptr = mat1_data + b * M * K;
    const scalar_t* mat2_ptr = mat2_data + b * N * K;

    utils::parallel_sparse_csr(crow_slice, M, nnz, [&](int64_t begin, int64_t end) {
      for (const auto m : c10::irange(begin, end)) {
        int64_t row_start = crow_slice[m];
        int64_t row_end = crow_slice[m + 1];
        for (const auto e : c10::irange(row_start, row_end)) {
          int64_t n = col_slice[e];
          scalar_t val = values_slice[e];
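          // Vectorized dot product of row m of mat1 with row n of mat2
          // (mat2 has layout {B, N, K}, i.e. its rows are indexed along N).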
          scalar_t dot = vec::map2_reduce_all<scalar_t>(
              [](Vec x, Vec y) { return x * y; },
              [](Vec x, Vec y) { return x + y; },
              mat1_ptr + m * K,
              mat2_ptr + n * K,
              K);
          val = alpha_ * dot + beta_ * val;
          values_slice[e] = val;
        }
      }
    });
  }
}

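// Dispatches on the value dtype of mat1 (floating-point and complex types) and
// on the index dtype of result's CSR indices (int32 or int64).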
void sampled_addmm_sparse_csr_kernel(
    const Tensor& mat1,
    const Tensor& mat2,
    const Scalar& beta,
    const Scalar& alpha,
    const Tensor& result) {
  const auto index_type = result.crow_indices().scalar_type();
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(mat1.scalar_type(), "sampled_addmm_sparse_csr_kernel", [&]() {
    AT_DISPATCH_INDEX_TYPES(index_type, "sampled_addmm_sparse_csr_index", [&]() {
      sampled_addmm_sparse_csr_kernel_impl<scalar_t, index_t>(mat1, mat2, beta, alpha, result);
    });
  });
}

} // anonymous namespace

REGISTER_DISPATCH(sampled_addmm_sparse_csr_stub, &sampled_addmm_sparse_csr_kernel);

} // at::native