#include <ATen/ScalarOps.h>
#include <ATen/Tensor.h>
#include <ATen/Functions.h>
#include <ATen/autocast_mode.h>
#include <c10/cuda/CUDAGuard.h>

#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
#else
#include <ATen/native/sparse/cuda/ComputeSparseTile.h>
#include <ATen/native/sparse/cuda/SparseSemiStructuredPack.h>
#endif

namespace at::native {

#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
#else
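// Kernel parameters, passed by value to the kernel below. `threads_masks`
// points to the per-thread sparsity selection masks (one Tile8x8Masks per 8x8
// tile); `input`/`output` are the dense data viewed as raw 16-bit values
// (fp16 or bf16). Tile8x8Masks and the kWarp*/kThread* tile sizes come from
// the sparse semi-structured headers included above.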
struct Params {
  uint64_t const* threads_masks;

  uint16_t const* input;
  int64_t input_stride;
  int64_t input_dim0;
  int64_t input_dim1;

  uint16_t* output;
  int64_t output_stride;

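  // One thread block covers a kWarpX x kWarpY tile of the input, so the grid
  // is the input shape rounded up to the tile size.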
  __host__ dim3 getBlocksGrid() const {
    return dim3(
        cutlass::ceil_div(input_dim0, kWarpX),
        cutlass::ceil_div(input_dim1, kWarpY),
        1);
  }

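  // Each thread owns a kThreadX x kThreadY sub-tile of its block's tile; the
  // kernel below assumes this sub-tile is 8x8.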
  static CUTLASS_HOST_DEVICE dim3 getThreadsGrid() {
    return dim3(kWarpX / kThreadX, kWarpY / kThreadY, 1);
  }

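  // Pointer to this thread's Tile8x8Masks entry. `threads_masks` is laid out
  // as a [grid.x * threads.x, grid.y * threads.y] row-major array of
  // Tile8x8Masks, with the y index contiguous (see the shape/stride checks in
  // the host function below).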
  CUTLASS_DEVICE Tile8x8Masks* getCurrentThreadIndices() const {
    Tile8x8Masks* gmem_threads_masks = (Tile8x8Masks*)threads_masks;
    gmem_threads_masks += blockIdx.y * getThreadsGrid().y + threadIdx.y;
    int64_t strideX = gridDim.y * getThreadsGrid().y;
    gmem_threads_masks +=
        (blockIdx.x * getThreadsGrid().x + threadIdx.x) * strideX;
    return gmem_threads_masks;
  }
};

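// Each thread reads its 8x8 tile of the dense input (transposing it in
// registers if the input is ColMajor), zeroes the elements not selected by the
// precomputed sparsity masks, and stores the result to the dense RowMajor
// output. Values are treated as opaque 16-bit patterns, so the same kernel
// serves both fp16 and bf16: multiplying the raw bits by 1 keeps a value and
// multiplying by 0 produces the zero bit pattern.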
template <bool kInputRowMajor = true, bool kOutputRowMajor = true>
__global__ void __launch_bounds__(32 /* num_threads */, 32)
    sparse_semi_structured_apply_dense_k(Params p) {
  using Fragment = cutlass::Array<uint16_t, 8>;

  // Top-left of the 8x8 tile we own
  int warp_x = blockIdx.x * kWarpX;
  int warp_y = blockIdx.y * kWarpY;
  int x = warp_x + threadIdx.x * kThreadX;
  int y = warp_y + threadIdx.y * kThreadY;

  uint16_t* output = p.output + x * p.output_stride + y;
  Tile8x8Masks indices = *p.getCurrentThreadIndices();

  // Load dense
  Fragment lines[8];
  if (kInputRowMajor) {
    uint16_t const* input = p.input + x * p.input_stride + y;
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < 8; ++i) {
      cutlass::arch::global_load<Fragment, sizeof(Fragment)>(
          lines[i], input + i * p.input_stride, true);
    }
  } else {
    uint16_t const* input = p.input + x + y * p.input_stride;
    Fragment columns[8];
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < 8; ++i) {
      cutlass::arch::global_load<Fragment, sizeof(Fragment)>(
          columns[i], input + i * p.input_stride, true);
    }
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < 8; ++i) {
      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < 8; ++j) {
        lines[i][j] = columns[j][i].get();
      }
    }
  }

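  // The 8x8 tile is split into four 4x4 quadrants, masked by indices.a/.b
  // (top row) and indices.c/.d (bottom row). Bit (4*r + c) of each Indices4x4
  // keeps element (r, c) of its quadrant; cleared bits zero the element.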
  CUTLASS_PRAGMA_UNROLL
  for (int row = 0; row < 2; ++row) {
    Indices4x4 masks[2];
    if (row == 0) {
      masks[0] = indices.a;
      masks[1] = indices.b;
    } else {
      masks[0] = indices.c;
      masks[1] = indices.d;
    }

    // Apply mask
    CUTLASS_PRAGMA_UNROLL
    for (int m = 0; m < 2; ++m) {
      CUTLASS_PRAGMA_UNROLL
      for (int r = 0; r < 4; ++r) {
        CUTLASS_PRAGMA_UNROLL
        for (int c = 0; c < 4; ++c) {
          lines[4 * row + r][4 * m + c] = lines[4 * row + r][4 * m + c] *
              int((masks[m] >> (4 * r + c)) & 1);
        }
      }
    }
  }
  static_assert(kOutputRowMajor, "Transpose here for ColMajor output");
  // Save dense with zeros
  CUTLASS_PRAGMA_UNROLL
  for (int i = 0; i < 8; ++i) {
    cutlass::arch::global_store<Fragment, sizeof(Fragment)>(
        lines[i], output + i * p.output_stride, true);
  }
}
#endif

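// Applies precomputed 2:4 sparsity masks to a dense 2-D fp16/bf16 tensor and
// returns a dense tensor of the same shape with the pruned elements set to
// zero. `threads_masks` is expected to be the per-thread mask tensor produced
// by the matching sparse semi-structured tiling kernel; its layout is checked
// below.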
Tensor _sparse_semi_structured_apply_dense(
    const Tensor& input,
    const Tensor& threads_masks) {

#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
  AT_ERROR("_sparse_semi_structured_apply_dense: not supported");
  return Tensor{};
#else
  TORCH_CHECK(
      input.scalar_type() == at::ScalarType::Half ||
          input.scalar_type() == at::ScalarType::BFloat16,
      "Unsupported `input` dtype");
  TORCH_CHECK(
      input.stride(0) == 1 || input.stride(1) == 1,
      "`input` should be either RowMajor or ColMajor. Invalid memory layout - try .contiguous()?");

  auto roundedx = cutlass::round_up(input.size(0), kWarpX);
  auto roundedy = cutlass::round_up(input.size(1), kWarpY);

  Params p;
  p.input = (uint16_t const*)input.data_ptr();
  p.input_dim0 = input.size(0);
  p.input_dim1 = input.size(1);
  p.threads_masks = (uint64_t const*)threads_masks.data_ptr();

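  // `threads_masks` must be a byte tensor of shape
  // [blocks.x * threads.x, blocks.y * threads.y, sizeof(uint64_t)] that is
  // contiguous in its last two dimensions, so each thread can read its
  // Tile8x8Masks entry directly (see Params::getCurrentThreadIndices).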
  TORCH_CHECK(threads_masks.dim() == 3);
  TORCH_CHECK(threads_masks.size(0) == p.getBlocksGrid().x * p.getThreadsGrid().x);
  TORCH_CHECK(threads_masks.size(1) == p.getBlocksGrid().y * p.getThreadsGrid().y);
  TORCH_CHECK(threads_masks.stride(1) == sizeof(p.threads_masks[0]));
  TORCH_CHECK(threads_masks.size(2) == sizeof(p.threads_masks[0]));
  TORCH_CHECK(threads_masks.stride(2) == 1);
  TORCH_CHECK(threads_masks.scalar_type() == at::ScalarType::Byte);

  at::Tensor output = at::empty({p.input_dim0, p.input_dim1}, input.options());
  TORCH_INTERNAL_ASSERT(output.stride(-1) == 1, "expected RowMajor?");
  p.output = (uint16_t*)output.data_ptr();

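  // Strides of the major dimension, in elements: the row stride for a RowMajor
  // tensor, the column stride for a ColMajor one.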
  bool inputRowMajor = input.stride(-1) == 1;
  bool outputRowMajor = output.stride(-1) == 1;
  p.input_stride = input.stride(inputRowMajor ? 0 : 1);
  p.output_stride = output.stride(outputRowMajor ? 0 : 1);
  at::cuda::CUDAGuard device_guard(input.device());

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  size_t smem_bytes = 0;
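  // Only RowMajor output is supported (the kernel static_asserts on it), so
  // the template dispatch only varies with the input layout.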
  if (inputRowMajor && outputRowMajor) {
    sparse_semi_structured_apply_dense_k<true, true>
        <<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes, stream>>>(p);
  } else if (!inputRowMajor && outputRowMajor) {
    sparse_semi_structured_apply_dense_k<false, true>
        <<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes, stream>>>(p);
  } else {
    TORCH_CHECK(
        false,
        "Unsupported configuration: `input` is ",
        inputRowMajor ? "RowMajor" : "ColMajor",
        ", and `output` is ",
        outputRowMajor ? "RowMajor" : "ColMajor");
  }
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return output;
#endif
}

} // namespace at::native