// aten/src/ATen/native/sparse/cuda/SparseSemiStructuredApplyDense.cu
#include <ATen/ScalarOps.h>
#include <ATen/Tensor.h>
#include <ATen/Functions.h>
#include <ATen/autocast_mode.h>
#include <c10/cuda/CUDAGuard.h>

#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
#else
#include <ATen/native/sparse/cuda/ComputeSparseTile.h>
#include <ATen/native/sparse/cuda/SparseSemiStructuredPack.h>
#endif

namespace at::native {

#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
#else
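// Kernel arguments: pointers and strides for the dense input and output
// (both handled as raw 16-bit words, so fp16 and bf16 share one kernel),
// plus the packed per-thread 2:4 selection masks (`threads_masks`) that
// decide which elements of each 8x8 tile are kept.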
struct Params {
  uint64_t const* threads_masks;

  uint16_t const* input;
  int64_t input_stride;
  int64_t input_dim0;
  int64_t input_dim1;

  uint16_t* output;
  int64_t output_stride;

  __host__ dim3 getBlocksGrid() const {
    return dim3(
        cutlass::ceil_div(input_dim0, kWarpX),
        cutlass::ceil_div(input_dim1, kWarpY),
        1);
  }

  static CUTLASS_HOST_DEVICE dim3 getThreadsGrid() {
    return dim3(kWarpX / kThreadX, kWarpY / kThreadY, 1);
  }

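  // `threads_masks` is laid out as a row-major
  // [gridDim.x * threadsGrid.x, gridDim.y * threadsGrid.y] array of
  // Tile8x8Masks: one entry per thread, indexed by the thread's global
  // (x, y) tile coordinate.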
  CUTLASS_DEVICE Tile8x8Masks* getCurrentThreadIndices() const {
    Tile8x8Masks* gmem_threads_masks = (Tile8x8Masks*)threads_masks;
    gmem_threads_masks += blockIdx.y * getThreadsGrid().y + threadIdx.y;
    int64_t strideX = gridDim.y * getThreadsGrid().y;
    gmem_threads_masks +=
        (blockIdx.x * getThreadsGrid().x + threadIdx.x) * strideX;
    return gmem_threads_masks;
  }
};

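// Each block is a single warp (32 threads; see __launch_bounds__). Every
// thread owns one 8x8 tile of the input: it loads the tile as raw 16-bit
// words, zeroes the entries dropped by its Tile8x8Masks, and stores the
// result densely, so pruned positions come out as zero.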
template <bool kInputRowMajor = true, bool kOutputRowMajor = true>
__global__ void __launch_bounds__(32 /* num_threads */, 32) sparse_semi_structured_apply_dense_k(Params p) {
  using Fragment = cutlass::Array<uint16_t, 8>;

  // Top-left of the 8x8 tile we own
  int warp_x = blockIdx.x * kWarpX;
  int warp_y = blockIdx.y * kWarpY;
  int x = warp_x + threadIdx.x * kThreadX;
  int y = warp_y + threadIdx.y * kThreadY;

  uint16_t* output = p.output + x * p.output_stride + y;
  Tile8x8Masks indices = *p.getCurrentThreadIndices();

  // Load dense
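  // (each Fragment is 8 contiguous 16-bit values, i.e. one row of the tile;
  // for ColMajor input we load 8 columns and transpose them in registers so
  // the masking and store code below is identical for both layouts)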
  Fragment lines[8];
  if (kInputRowMajor) {
    uint16_t const* input = p.input + x * p.input_stride + y;
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < 8; ++i) {
      cutlass::arch::global_load<Fragment, sizeof(Fragment)>(
          lines[i], input + i * p.input_stride, true);
    }
  } else {
    uint16_t const* input = p.input + x + y * p.input_stride;
    Fragment columns[8];
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < 8; ++i) {
      cutlass::arch::global_load<Fragment, sizeof(Fragment)>(
          columns[i], input + i * p.input_stride, true);
    }
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < 8; ++i) {
      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < 8; ++j) {
        lines[i][j] = columns[j][i].get();
      }
    }
  }

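  // The 8x8 tile is processed as four 4x4 quadrants mapped to the masks
  // a/b/c/d; bit (4 * r + c) of a quadrant's mask says whether element
  // (r, c) is kept. Multiplying the raw 16-bit pattern by 0 or 1 either
  // preserves it or turns it into 0x0000 (+0.0 in fp16/bf16).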
  CUTLASS_PRAGMA_UNROLL
  for (int row = 0; row < 2; ++row) {
    Indices4x4 masks[2];
    if (row == 0) {
      masks[0] = indices.a;
      masks[1] = indices.b;
    } else {
      masks[0] = indices.c;
      masks[1] = indices.d;
    }

    // Apply mask
    CUTLASS_PRAGMA_UNROLL
    for (int m = 0; m < 2; ++m) {
      CUTLASS_PRAGMA_UNROLL
      for (int r = 0; r < 4; ++r) {
        CUTLASS_PRAGMA_UNROLL
        for (int c = 0; c < 4; ++c) {
          lines[4 * row + r][4 * m + c] = lines[4 * row + r][4 * m + c] *
              int((masks[m] >> (4 * r + c)) & 1);
        }
      }
    }
  }
  static_assert(kOutputRowMajor, "Transpose here for ColMajor output");
  // Save dense with zeros
  CUTLASS_PRAGMA_UNROLL
  for (int i = 0; i < 8; ++i) {
    cutlass::arch::global_store<Fragment, sizeof(Fragment)>(
        lines[i], output + i * p.output_stride, true);
  }
}
#endif

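// Host entry point: validates dtype and layout, checks that `threads_masks`
// matches the launch geometry, allocates a RowMajor output, and dispatches
// the kernel specialization for the input layout (only RowMajor output is
// supported).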
Tensor _sparse_semi_structured_apply_dense(
    const Tensor& input,
    const Tensor& threads_masks) {

#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
  AT_ERROR("_sparse_semi_structured_apply_dense: not supported");
  return Tensor{};
#else
  TORCH_CHECK(
      input.scalar_type() == at::ScalarType::Half ||
          input.scalar_type() == at::ScalarType::BFloat16,
      "Unsupported `input` dtype");
  TORCH_CHECK(
      input.stride(0) == 1 || input.stride(1) == 1,
      "`input` should be either RowMajor or ColMajor. Invalid memory layout - try .contiguous()?");

  auto roundedx = cutlass::round_up(input.size(0), kWarpX);
  auto roundedy = cutlass::round_up(input.size(1), kWarpY);

  Params p;
  p.input = (uint16_t const*)input.data_ptr();
  p.input_dim0 = input.size(0);
  p.input_dim1 = input.size(1);
  p.threads_masks = (uint64_t const*)threads_masks.data_ptr();

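  // `threads_masks` must be a contiguous Byte tensor of shape
  // [blocks.x * threads.x, blocks.y * threads.y, 8], so that each innermost
  // 8-byte row holds one Tile8x8Masks entry read by the kernel.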
  TORCH_CHECK(threads_masks.dim() == 3);
  TORCH_CHECK(threads_masks.size(0) == p.getBlocksGrid().x * p.getThreadsGrid().x);
  TORCH_CHECK(threads_masks.size(1) == p.getBlocksGrid().y * p.getThreadsGrid().y);
  TORCH_CHECK(threads_masks.stride(1) == sizeof(p.threads_masks[0]));
  TORCH_CHECK(threads_masks.size(2) == sizeof(p.threads_masks[0]));
  TORCH_CHECK(threads_masks.stride(2) == 1);
  TORCH_CHECK(threads_masks.scalar_type() == at::ScalarType::Byte);

  at::Tensor output = at::empty({p.input_dim0, p.input_dim1}, input.options());
  TORCH_INTERNAL_ASSERT(output.stride(-1) == 1, "expected RowMajor?");
  p.output = (uint16_t*)output.data_ptr();

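  // Leading-dimension strides in elements: stride(0) for RowMajor,
  // stride(1) for ColMajor; the kernel template parameters select the
  // matching load pattern.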
  bool inputRowMajor = input.stride(-1) == 1;
  bool outputRowMajor = output.stride(-1) == 1;
  p.input_stride = input.stride(inputRowMajor ? 0 : 1);
  p.output_stride = output.stride(outputRowMajor ? 0 : 1);
  at::cuda::CUDAGuard device_guard(input.device());

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  size_t smem_bytes = 0;
  if (inputRowMajor && outputRowMajor) {
    sparse_semi_structured_apply_dense_k<true, true>
        <<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes, stream>>>(p);
  } else if (!inputRowMajor && outputRowMajor) {
    sparse_semi_structured_apply_dense_k<false, true>
        <<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes, stream>>>(p);
  } else {
    TORCH_CHECK(
        false,
        "Unsupported configuration: `input` is ",
        inputRowMajor ? "RowMajor" : "ColMajor",
        ", and `output` is ",
        outputRowMajor ? "RowMajor" : "ColMajor");
  }
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return output;
#endif
}
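
// Illustrative usage sketch (not part of this file; the shapes below are
// hypothetical, and `threads_masks` is assumed to come from the matching
// 2:4 tile/pack kernel that produces the layout checked above):
//
//   at::Tensor dense = at::randn(
//       {128, 128}, at::device(at::kCUDA).dtype(at::kHalf));
//   // threads_masks: Byte tensor of shape
//   //   [blocksGrid.x * threadsGrid.x, blocksGrid.y * threadsGrid.y, 8]
//   at::Tensor pruned =
//       at::native::_sparse_semi_structured_apply_dense(dense, threads_masks);
//   // `pruned` equals `dense` with the entries dropped by the 2:4 masks
//   // zeroed out.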

} // namespace