// aten/src/ATen/native/sparse/cuda/SparseSemiSturcturedApply.cu
#include <ATen/ScalarOps.h>
#include <ATen/Tensor.h>
#include <ATen/Functions.h>
#include <ATen/Utils.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/util/accumulate.h>

#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
#else
#include <ATen/native/sparse/cuda/SparseSemiStructuredPack.h>
#endif

namespace at::native {

#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
#else
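// Thin __global__ trampoline: the actual packing logic lives in
// KT::sparse_semi_structured_apply_kernel (defined in
// SparseSemiStructuredPack.h); this wrapper only gives it a launchable
// symbol, bounded to 32 threads (one warp) per block.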
template <typename KT>
__global__ void __launch_bounds__(32 /* num_threads */)
  sparse_semi_structured_apply_kernel(typename KT::Params p)
{
  KT::sparse_semi_structured_apply_kernel(p);
}

// Apply a 2:4 sparsity pattern computed with
// `_sparse_semi_structured_tile` to another Tensor
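//
// A rough Python-level sketch of how the pair of ops is typically used
// (names and tuple layout are illustrative; check the current torch
// private-op signatures before relying on them):
//   packed, meta, packed_t, meta_t, bitmask = torch._sparse_semi_structured_tile(x)
//   out, out_t = torch._sparse_semi_structured_apply(y, bitmask)
// i.e. `y` is pruned with the 2:4 mask that was computed for `x`.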
template <bool kIsMeta, typename Element>
std::tuple<Tensor, Tensor> _sparse_semi_structured_apply_typed(Tensor input, Tensor threads_masks)
{
  using KT = KernelTypes<Element>;
  // TODO: Technically we should be able to handle a non-contiguous
  // `input` (stride(1) != 1) by running on the transpose of `input`
  // and swapping `packed` & `packed_t`.
  // This would also require adapting `threads_masks`, though.
  if (input.stride(1) != 1) {
    input = input.contiguous();
  }
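  // When kIsMeta is true we only compute output shapes, so there is no
  // device to guard (and the kernel launch below is skipped as well).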
  std::optional<at::cuda::CUDAGuard> device_guard;
  if (!kIsMeta) {
    device_guard.emplace(input.device());
  }

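  // Shape/stride preconditions: a row-major 2D tensor whose rows are
  // 8-element aligned and whose inner dimension is a multiple of 32.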
  TORCH_CHECK(input.dim() == 2);
  TORCH_CHECK(input.stride(1) == 1);
  TORCH_CHECK(input.stride(0) % 8 == 0);
  TORCH_CHECK(input.size(1) % 32 == 0, "Wrong alignment: shape[1] must be a multiple of 32");

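  // 2:4 sparsity keeps 2 of every 4 elements, so each packed buffer halves
  // the inner dimension; sizes are rounded up to whole warp tiles
  // (kWarpX x kWarpY, from SparseSemiStructuredPack.h). `packed_trans`
  // holds the packed transpose so both operand layouts are available.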
  auto roundedx = cutlass::round_up(input.size(0), kWarpX);
  auto roundedy = cutlass::round_up(input.size(1), kWarpY);
  at::Tensor packed =
      at::empty({roundedx, cutlass::ceil_div(roundedy, 2)}, input.options());
  at::Tensor packed_trans =
      at::empty({roundedy, cutlass::ceil_div(roundedx, 2)}, input.options());

  typename KT::Params p;
  p.input = (Element const*)input.data_ptr();
  p.input_s0 = input.stride(0);
  p.input_dim0 = input.size(0);
  p.input_dim1 = input.size(1);

  p.packed = (Element*)packed.data_ptr();
  p.packed_stride = packed.stride(0);
  p.packed_trans = (Element*)packed_trans.data_ptr();
  p.packed_trans_stride = packed_trans.stride(0);

  p.threads_masks = (uint64_t*)threads_masks.data_ptr();

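  // `threads_masks` comes from `_sparse_semi_structured_tile` as a Byte
  // tensor of shape (blocks.x * threads.x, blocks.y * threads.y, 8): one
  // 8-byte mask per participating thread, which is why size(2) and
  // stride(1) must both equal sizeof(uint64_t).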
  TORCH_CHECK(threads_masks.dim() == 3);
  TORCH_CHECK(
      threads_masks.size(0) == p.getBlocksGrid().x * p.getThreadsGrid().x);
  TORCH_CHECK(
      threads_masks.size(1) == p.getBlocksGrid().y * p.getThreadsGrid().y);
  TORCH_CHECK(threads_masks.stride(1) == sizeof(p.threads_masks[0]));
  TORCH_CHECK(threads_masks.size(2) == sizeof(p.threads_masks[0]));
  TORCH_CHECK(threads_masks.stride(2) == 1);
  TORCH_CHECK(threads_masks.scalar_type() == at::ScalarType::Byte);

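  // Launch on the current CUDA stream; no shared memory is needed. For
  // meta tensors the launch is skipped and only the (rounded) output
  // shapes matter.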
  if (!kIsMeta) {
    size_t smem_bytes = 0;
    sparse_semi_structured_apply_kernel<KT>
        <<<p.getBlocksGrid(),
           p.getThreadsGrid(),
           smem_bytes,
           at::cuda::getCurrentCUDAStream()>>>(p);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
  return std::make_tuple(packed, packed_trans);
}
#endif

std::tuple<Tensor, Tensor> _sparse_semi_structured_apply(
    const Tensor& input,
    const Tensor& threads_masks) // Returned by `_sparse_semi_structured_tile`
{
#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
  AT_ERROR("_sparse_semi_structured_apply: not supported");
  return std::make_tuple(Tensor{}, Tensor{});
#else
  TORCH_CHECK(
    input.scalar_type() == at::ScalarType::Half || input.scalar_type() == at::ScalarType::BFloat16,
    "Unsupported dtype - only `float16` and `bfloat16` are supported currently"
  );
  auto result = (input.scalar_type() == at::ScalarType::Half)
            ? _sparse_semi_structured_apply_typed<false, cutlass::half_t>(input, threads_masks)
            : _sparse_semi_structured_apply_typed<false, cutlass::bfloat16_t>(input, threads_masks);
  return result;
#endif
}

} // namespace at::native