#include <ATen/ScalarOps.h>
#include <ATen/Tensor.h>
#include <ATen/Functions.h>
#include <ATen/Utils.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/util/accumulate.h>

#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
#else
#include <ATen/native/sparse/cuda/SparseSemiStructuredPack.h>
#endif

namespace at::native {

#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
#else
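// Thin launcher: a __global__ function cannot be a member function, so this
// wrapper forwards to the kernel body implemented as a static method of `KT`
// (KernelTypes, presumably defined in SparseSemiStructuredPack.h).
// __launch_bounds__(32) promises the compiler at most 32 threads (one warp)
// per block.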
template <typename KT>
__global__ void __launch_bounds__(32 /* num_threads */)
    sparse_semi_structured_apply_kernel(typename KT::Params p)
{
  KT::sparse_semi_structured_apply_kernel(p);
}

// Apply the 2:4 sparsity pattern computed by
// `_sparse_semi_structured_tile` to another Tensor
template <bool kIsMeta, typename Element>
std::tuple<Tensor, Tensor> _sparse_semi_structured_apply_typed(Tensor input, Tensor threads_masks)
{
  using KT = KernelTypes<Element>;
  // TODO: Technically we could handle a non-contiguous last dimension by
  // running on the transpose of `input` and swapping `packed` & `packed_t`,
  // but that would also require adapting `threads_masks`.
  if (input.stride(1) != 1) {
    input = input.contiguous();
  }
  std::optional<at::cuda::CUDAGuard> device_guard;
  if (!kIsMeta) {
    device_guard.emplace(input.device());
  }

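  // The alignment checks below presumably mirror the kernel's vectorized
  // access pattern (rows contiguous, row stride a multiple of 8 elements,
  // 32-column tiles); inputs that fail them would need padding upstream.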
  TORCH_CHECK(input.dim() == 2);
  TORCH_CHECK(input.stride(1) == 1);
  TORCH_CHECK(input.stride(0) % 8 == 0);
  TORCH_CHECK(
      input.size(1) % 32 == 0,
      "`input.shape[1]` must be a multiple of 32, got ", input.size(1));

  auto roundedx = cutlass::round_up(input.size(0), kWarpX);
  auto roundedy = cutlass::round_up(input.size(1), kWarpY);
  at::Tensor packed =
      at::empty({roundedx, cutlass::ceil_div(roundedy, 2)}, input.options());
  at::Tensor packed_trans =
      at::empty({roundedy, cutlass::ceil_div(roundedx, 2)}, input.options());
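  // Note: 2:4 sparsity keeps 2 out of every 4 values, so the packed buffers
  // hold half as many columns as the (padded) dense input; rows and columns
  // are rounded up to the warp tile sizes kWarpX/kWarpY consumed by the
  // kernel, so the buffers may be slightly larger than strictly needed.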

  typename KT::Params p;
  p.input = (Element const*)input.data_ptr();
  p.input_s0 = input.stride(0);
  p.input_dim0 = input.size(0);
  p.input_dim1 = input.size(1);

  p.packed = (Element*)packed.data_ptr();
  p.packed_stride = packed.stride(0);
  p.packed_trans = (Element*)packed_trans.data_ptr();
  p.packed_trans_stride = packed_trans.stride(0);

  p.threads_masks = (uint64_t*)threads_masks.data_ptr();

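  // Expected `threads_masks` layout, as enforced below: a contiguous uint8
  // tensor of shape
  //   [blocks.x * threads.x, blocks.y * threads.y, sizeof(uint64_t)],
  // i.e. one 8-byte bitmask per participating thread.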
  TORCH_CHECK(threads_masks.dim() == 3);
  TORCH_CHECK(
      threads_masks.size(0) == p.getBlocksGrid().x * p.getThreadsGrid().x);
  TORCH_CHECK(
      threads_masks.size(1) == p.getBlocksGrid().y * p.getThreadsGrid().y);
  TORCH_CHECK(threads_masks.stride(1) == sizeof(p.threads_masks[0]));
  TORCH_CHECK(threads_masks.size(2) == sizeof(p.threads_masks[0]));
  TORCH_CHECK(threads_masks.stride(2) == 1);
  TORCH_CHECK(threads_masks.scalar_type() == at::ScalarType::Byte);

  if (!kIsMeta) {
    size_t smem_bytes = 0;
    sparse_semi_structured_apply_kernel<KT>
        <<<p.getBlocksGrid(),
           p.getThreadsGrid(),
           smem_bytes,
           at::cuda::getCurrentCUDAStream()>>>(p);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
  return std::make_tuple(packed, packed_trans);
}
#endif

std::tuple<Tensor, Tensor> _sparse_semi_structured_apply(
    const Tensor& input,
    const Tensor& threads_masks) // Returned by `_sparse_semi_structured_tile`
{
#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
  AT_ERROR("_sparse_semi_structured_apply: not supported");
  return std::make_tuple(Tensor{}, Tensor{});
#else
  TORCH_CHECK(
      input.scalar_type() == at::ScalarType::Half || input.scalar_type() == at::ScalarType::BFloat16,
      "Unsupported dtype - only `float16` and `bfloat16` are currently supported"
  );
  auto result = (input.scalar_type() == at::ScalarType::Half)
      ? _sparse_semi_structured_apply_typed<false, cutlass::half_t>(input, threads_masks)
      : _sparse_semi_structured_apply_typed<false, cutlass::bfloat16_t>(input, threads_masks);
  return result;
#endif
}
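
// Usage sketch (not part of this file; a minimal example assuming that
// `at::_sparse_semi_structured_tile` exposes the per-thread bitmask as the
// last element of its output tuple - that layout is an assumption, not
// something defined here):
//
//   at::Tensor x = at::randn({128, 128},
//       at::TensorOptions().device(at::kCUDA).dtype(at::kHalf));
//   auto tiled = at::_sparse_semi_structured_tile(x);  // compute 2:4 pattern on x
//   at::Tensor bitmask = std::get<4>(tiled);           // assumed position of the mask
//   at::Tensor y = at::randn_like(x);                  // another tensor, same shape
//   auto [y_packed, y_packed_t] =
//       at::_sparse_semi_structured_apply(y, bitmask); // reuse x's pattern on y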

} // namespace at::native