#include <ATen/ATen.h>
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAUtils.h>
#include <ATen/Dispatch.h>
#include <c10/cuda/CUDAException.h>

#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
#else
#include <cuda_runtime.h>
#include <cutlass/cutlass.h>
#include <cutlass/layout/layout.h>
#include <cutlass/tensor_ref.h>
#include <cutlass/gemm/device/gemm_sparse_with_visitor.h>
#include <cutlass/epilogue/threadblock/fusion/visitors.hpp>
#endif

#include <type_traits>
#include <tuple>

#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
#else
#define CUTLASS_STATUS_CHECK(status)                       \
  {                                                        \
    TORCH_CHECK(status == cutlass::Status::kSuccess,       \
                __func__, " : CUTLASS error: ",            \
                cutlassGetStatusString(status));           \
  }
#endif
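
// CUTLASS_STATUS_CHECK is used after every CUTLASS call below, for example:
//
//   cutlass::Status status = gemm_op.can_implement(arguments);
//   CUTLASS_STATUS_CHECK(status);
//
// so that any non-success status surfaces as an error carrying the CUTLASS
// status string.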

namespace at::native {

#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
#else
// Wrapper function for the CUTLASS sparse GEMM implementation, used
// solely to simplify dispatching from the
// sparse_semi_structured_mad_op() function below.
template <
    typename ElementInputA,
    typename ElementInputB,
    typename ElementOutput,
    typename ElementAccumulator,
    typename ThreadblockShape,
    typename WarpShape,
    typename InstructionShape,
    typename LayoutInputA,
    typename LayoutInputB,
    bool use_tensor_c>
void spgemm_cutlass(
    const Tensor& tensor_a, const at::IntArrayRef::value_type& tensor_a_stride,
    const Tensor& tensor_b, const at::IntArrayRef::value_type& tensor_b_stride,
    const Tensor& tensor_c, const Tensor& tensor_e, const Scalar& alpha,
    const Scalar& beta, Tensor& tensor_d) {
  // Fix the CUTLASS sparse GEMM template arguments that are not
  // provided as template arguments of this function, and create an
  // alias for the particular instantiation of this template.
  using LayoutOutput = cutlass::layout::RowMajor; // Result of the operation will be provided in row-major format.
  using MMAOp = cutlass::arch::OpClassTensorOp; // Tensor cores are to be used for maximum performance.
  using SmArch = cutlass::arch::Sm80; // Only CC 8.x devices are supported at the moment.
  using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // This choice provides good performance across a wide range of operand sizes.
  constexpr int NumStages = 3; // This choice provides good performance across a wide range of operand sizes.
  using Operator = cutlass::arch::OpMultiplyAdd;
  constexpr int NumEVTEpilogueStages = 1;

  constexpr int AlignmentInputA = 128 / cutlass::sizeof_bits<ElementInputA>::value;
  constexpr int AlignmentInputB = 128 / cutlass::sizeof_bits<ElementInputB>::value;
  constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits<ElementOutput>::value;
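
  // The alignments above are expressed in elements per 128-bit vectorized
  // access. For illustration: with cutlass::half_t (16-bit) operands this
  // gives 128 / 16 = 8 elements, while with int8_t operands it gives
  // 128 / 8 = 16 elements.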

  using ElementComputeEpilogue = ElementAccumulator; // Typically slightly slower, but more precise than if ElementOutput were used.
  constexpr int AlignmentComputeEpilogue = 128 / cutlass::sizeof_bits<ElementComputeEpilogue>::value;
  using ElementC = ElementOutput;
  using LayoutC = LayoutOutput;
  constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;

  using TensorCTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
      ThreadblockShape,
      WarpShape,
      ElementC,
      AlignmentC,
      NumEVTEpilogueStages>;
  using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
      ThreadblockShape,
      WarpShape,
      ElementOutput,
      AlignmentOutput,
      NumEVTEpilogueStages>;

  using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;

  using Alpha =
      cutlass::epilogue::threadblock::VisitorScalarBroadcast<ElementComputeEpilogue>;
  using AlphaArguments = typename Alpha::Arguments;

  using ApplyAlpha = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, ElementComputeEpilogue, ElementComputeEpilogue,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using EVTApplyAlpha = cutlass::epilogue::threadblock::Sm80EVT<
      ApplyAlpha,
      Alpha,
      Accum>;

  using Beta =
      cutlass::epilogue::threadblock::VisitorScalarBroadcast<ElementComputeEpilogue>;
  using BetaArguments = typename Beta::Arguments;

  using TensorCScalar =
      cutlass::epilogue::threadblock::VisitorScalarBroadcast<ElementC>;
  using TensorCTensor =
      cutlass::epilogue::threadblock::VisitorColBroadcast<
          TensorCTileThreadMap,
          ElementC,
          cute::Stride<cute::_1, cute::_0, int64_t>>;
  using TensorC = std::conditional_t<use_tensor_c, TensorCTensor, TensorCScalar>;
  using TensorCArguments = typename TensorC::Arguments;

  using ApplyBeta = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, ElementComputeEpilogue, ElementComputeEpilogue,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using EVTApplyBeta = cutlass::epilogue::threadblock::Sm80EVT<
      ApplyBeta,
      Beta,
      TensorC>;

  using ApplySum = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::plus, ElementComputeEpilogue, ElementComputeEpilogue,
      cutlass::FloatRoundStyle::round_to_nearest>;
  using EVTApplySum = cutlass::epilogue::threadblock::Sm80EVT<
      ApplySum,
      EVTApplyAlpha,
      EVTApplyBeta>;

  using Output = cutlass::epilogue::threadblock::VisitorAuxStore<
      OutputTileThreadMap, ElementOutput, cutlass::FloatRoundStyle::round_to_nearest,
      cute::Stride<int64_t, cute::_1, int64_t>>;

  using EVTOutput = cutlass::epilogue::threadblock::Sm80EVT<
      Output,
      EVTApplySum>;
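
  // The epilogue visitor tree (EVT) built above evaluates, per output tile:
  //
  //   EVTApplyAlpha = alpha * Accum          (Accum is the GEMM accumulator A @ B)
  //   EVTApplyBeta  = beta * TensorC         (TensorC is either tensor_c or a zero scalar)
  //   EVTApplySum   = EVTApplyAlpha + EVTApplyBeta
  //   EVTOutput     = store EVTApplySum into tensor_d
  //
  // i.e. D = alpha * (A @ B) + beta * C, matching the addmm semantics
  // documented for sparse_semi_structured_mad_op() below.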

  using Gemm = cutlass::gemm::device::SparseGemmWithVisitor<
      ElementInputA,
      LayoutInputA,
      ElementInputB,
      LayoutInputB,
      ElementC,
      LayoutC,
      ElementAccumulator,
      MMAOp,
      SmArch,
      ThreadblockShape,
      WarpShape,
      InstructionShape,
      EVTOutput,
      SwizzleThreadBlock,
      NumStages,
      AlignmentInputA,
      AlignmentInputB,
      Operator,
      NumEVTEpilogueStages>;

  // Datatype and layout of the metadata matrix are inferred from the
  // sparse GEMM template.
  using ElementInputE = typename Gemm::ElementE;
  using LayoutInputE = cutlass::layout::RowMajor;
  using ReorderedLayoutInputE = typename Gemm::LayoutE;
  static_assert(
      std::is_same<ReorderedLayoutInputE,
                   cutlass::layout::ColumnMajorInterleaved<2>>::value,
      "The matrix layout used by CUTLASS for reordered sparse GEMM metadata "
      "has changed; the code converting metadata from/to dense matrices has "
      "to be updated accordingly.");

  constexpr auto kSparse = Gemm::kSparse;
  constexpr int kElementsPerElementE = Gemm::kElementsPerElementE;

  // Operand sizes.
  const int length_m = tensor_a.size(0);
  const int length_k = tensor_b.size(0);
  const int length_n = tensor_b.size(1);
  const auto tensor_e_ncols = length_k / kSparse / kElementsPerElementE;
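
  // Note that length_k is taken from tensor_b, because tensor_a is supplied
  // in compressed (semi-structured) form with only length_k / kSparse
  // columns. As a purely illustrative example of the formula above: if
  // kSparse == 2 and kElementsPerElementE == 8, then length_k == 128 gives
  // 128 / 2 / 8 == 8 metadata columns per row.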

  // Determine the PyTorch datatype for the metadata matrix.
  auto tensor_e_dtype = at::kChar;
  switch (sizeof(ElementInputE)) {
  case 2:
    tensor_e_dtype = at::kShort;
    break;
  case 4:
    tensor_e_dtype = at::kInt;
    break;
  default:
    AT_ERROR(__func__, ": invalid size of meta tensor datatype "
             "encountered");
  }
  TORCH_CHECK(tensor_e.dtype() == tensor_e_dtype,
              __func__, " : Expected meta datatype ", tensor_e_dtype,
              ", but got ", tensor_e.dtype());

  // Prepare arguments for the CUTLASS sparse GEMM kernel.
  cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k);
  LayoutInputA layout_a(tensor_a_stride);
  LayoutInputB layout_b(tensor_b_stride);
  auto tensor_a_device_ref =
      cutlass::TensorRef<ElementInputA, LayoutInputA>(
          (ElementInputA*)tensor_a.data_ptr(), layout_a);
  auto tensor_b_device_ref =
      cutlass::TensorRef<ElementInputB, LayoutInputB>(
          (ElementInputB*)tensor_b.data_ptr(), layout_b);
  auto tensor_e_reordered_device_ref =
      cutlass::TensorRef<ElementInputE, ReorderedLayoutInputE>(
          (ElementInputE*)tensor_e.data_ptr(),
          ReorderedLayoutInputE::packed({length_m, tensor_e_ncols}));

  AlphaArguments alpha_arguments{
    [&]() -> AlphaArguments {
      if constexpr (std::is_same<ElementComputeEpilogue, cutlass::half_t>::value ||
                    std::is_same<ElementComputeEpilogue, cutlass::bfloat16_t>::value) {
        return {ElementComputeEpilogue{alpha.to<float>()}};
      } else {
        return {alpha.to<ElementComputeEpilogue>()};
      }
    }()
  };
  BetaArguments beta_arguments{
    [&]() -> BetaArguments {
      if constexpr (std::is_same<ElementComputeEpilogue, cutlass::half_t>::value ||
                    std::is_same<ElementComputeEpilogue, cutlass::bfloat16_t>::value) {
        return {ElementComputeEpilogue{beta.to<float>()}};
      } else {
        return {beta.to<ElementComputeEpilogue>()};
      }
    }()
  };
  TensorCArguments tensor_c_arguments{
    [&]() -> TensorCArguments {
      if constexpr (use_tensor_c) {
        return {(ElementC*)tensor_c.data_ptr(),
                ElementC(0),
                {cute::_1{}, cute::_0{}, problem_size.m()}};
      } else {
        return {ElementC(0)};
      }
    }()
  };
  typename Output::Arguments output_arguments{
    (ElementOutput*)tensor_d.data_ptr(),
    {problem_size.n(), cute::_1{}, problem_size.mn().product()}
  };
  typename EVTOutput::Arguments callback_arguments{
    {
      {
        alpha_arguments,     // Alpha
        {},                  // Accum
        {}                   // ApplyAlpha
      },                     // EVTApplyAlpha
      {
        beta_arguments,      // Beta
        tensor_c_arguments,  // TensorC
        {}                   // ApplyBeta
      },                     // EVTApplyBeta
      {}                     // ApplySum
    },                       // EVTApplySum
    output_arguments         // Output
  };                         // EVTOutput

  // Create the CUTLASS sparse GEMM kernel arguments.
  typename Gemm::Arguments arguments{
    problem_size,
    tensor_a_device_ref,
    tensor_b_device_ref,
    tensor_e_reordered_device_ref,
    callback_arguments};

  cutlass::Status status;

  // Create a CUTLASS sparse GEMM kernel object.
  Gemm gemm_op;

  // Verify that the sparse GEMM operation with the given arguments can be
  // performed by CUTLASS.
  status = gemm_op.can_implement(arguments);
  CUTLASS_STATUS_CHECK(status);

  // Allocate workspace for the CUTLASS sparse GEMM kernel.
  const auto workspace_size = Gemm::get_workspace_size(arguments);
  auto workspace = tensor_a.new_empty({(int64_t)workspace_size},
                                      at::TensorOptions().dtype(at::kByte));

  // Initialize the CUTLASS sparse GEMM object.
  status = gemm_op.initialize(arguments, workspace.data_ptr(),
                              at::cuda::getCurrentCUDAStream());
  CUTLASS_STATUS_CHECK(status);

  // Perform the sparse GEMM operation.
  status = gemm_op.run(at::cuda::getCurrentCUDAStream());
  CUTLASS_STATUS_CHECK(status);

  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

// Dispatch according to the input tensors' layout combination.
template <
    typename ElementInputA,
    typename ElementInputB,
    typename ElementOutput,
    typename ElementAccumulator,
    typename ThreadblockShape,
    typename WarpShape,
    typename InstructionShape,
    bool EnableRowMajorRowMajorLayouts,
    bool EnableRowMajorColumnMajorLayouts,
    bool EnableColumnMajorRowMajorLayouts,
    bool EnableColumnMajorColumnMajorLayouts,
    bool use_tensor_c>
void spgemm_cutlass_dispatch_layouts(
    const Tensor& tensor_a, const Tensor& tensor_b, const Tensor& tensor_c,
    const Tensor& tensor_e, const Scalar& alpha, const Scalar& beta,
    Tensor& tensor_d) {
  // Determine layouts (row-major or column-major) of input tensors.
  const auto strides_a = tensor_a.strides();
  auto tensor_a_row_major = strides_a[1] == 1;
  auto tensor_a_stride = tensor_a_row_major ? strides_a[0] : strides_a[1];
  const auto strides_b = tensor_b.strides();
  auto tensor_b_row_major = strides_b[1] == 1;
  auto tensor_b_stride = tensor_b_row_major ? strides_b[0] : strides_b[1];
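
  // For illustration: a contiguous R x C tensor has strides {C, 1}, so it is
  // treated as row-major with leading-dimension stride C; the transpose of a
  // contiguous tensor, with shape R x C and strides {1, R}, is treated as
  // column-major with leading-dimension stride R.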

  // Perform dispatching.
  if constexpr (EnableRowMajorRowMajorLayouts) {
    if (tensor_a_row_major && tensor_b_row_major) {
      spgemm_cutlass<
          ElementInputA,
          ElementInputB,
          ElementOutput,
          ElementAccumulator,
          ThreadblockShape,
          WarpShape,
          InstructionShape,
          cutlass::layout::RowMajor,
          cutlass::layout::RowMajor,
          use_tensor_c>(
          tensor_a,
          tensor_a_stride,
          tensor_b,
          tensor_b_stride,
          tensor_c,
          tensor_e,
          alpha,
          beta,
          tensor_d);
      return;
    }
  }
  if constexpr (EnableRowMajorColumnMajorLayouts) {
    if (tensor_a_row_major && !tensor_b_row_major) {
      spgemm_cutlass<
          ElementInputA,
          ElementInputB,
          ElementOutput,
          ElementAccumulator,
          ThreadblockShape,
          WarpShape,
          InstructionShape,
          cutlass::layout::RowMajor,
          cutlass::layout::ColumnMajor,
          use_tensor_c>(
          tensor_a,
          tensor_a_stride,
          tensor_b,
          tensor_b_stride,
          tensor_c,
          tensor_e,
          alpha,
          beta,
          tensor_d);
      return;
    }
  }
  if constexpr (EnableColumnMajorRowMajorLayouts) {
    if (!tensor_a_row_major && tensor_b_row_major) {
      spgemm_cutlass<
          ElementInputA,
          ElementInputB,
          ElementOutput,
          ElementAccumulator,
          ThreadblockShape,
          WarpShape,
          InstructionShape,
          cutlass::layout::ColumnMajor,
          cutlass::layout::RowMajor,
          use_tensor_c>(
          tensor_a,
          tensor_a_stride,
          tensor_b,
          tensor_b_stride,
          tensor_c,
          tensor_e,
          alpha,
          beta,
          tensor_d);
      return;
    }
  }
  if constexpr (EnableColumnMajorColumnMajorLayouts) {
    if (!tensor_a_row_major && !tensor_b_row_major) {
      spgemm_cutlass<
          ElementInputA,
          ElementInputB,
          ElementOutput,
          ElementAccumulator,
          ThreadblockShape,
          WarpShape,
          InstructionShape,
          cutlass::layout::ColumnMajor,
          cutlass::layout::ColumnMajor,
          use_tensor_c>(
          tensor_a,
          tensor_a_stride,
          tensor_b,
          tensor_b_stride,
          tensor_c,
          tensor_e,
          alpha,
          beta,
          tensor_d);
      return;
    }
  }

  AT_ERROR(__func__, "_dispatch_layouts: Combination of ",
           tensor_a_row_major ? "row-major" : "column-major", " and ",
           tensor_b_row_major ? "row-major" : "column-major",
           " layouts for input tensors is not supported");
}

// Dispatch according to the tensor_c tensor being provided or not.
template <
    typename ElementInputA,
    typename ElementInputB,
    typename ElementOutput,
    typename ElementAccumulator,
    typename ThreadblockShape,
    typename WarpShape,
    typename InstructionShape,
    bool EnableRowMajorRowMajorLayouts,
    bool EnableRowMajorColumnMajorLayouts,
    bool EnableColumnMajorRowMajorLayouts,
    bool EnableColumnMajorColumnMajorLayouts>
void spgemm_cutlass_dispatch_layouts_tensor_c(
    const Tensor& tensor_a, const Tensor& tensor_b, const Tensor& tensor_c,
    const Tensor& tensor_e, const Scalar& alpha, const Scalar& beta,
    Tensor& tensor_d) {
  if (tensor_c.numel() > 0) {
    spgemm_cutlass_dispatch_layouts<
        ElementInputA,
        ElementInputB,
        ElementOutput,
        ElementAccumulator,
        ThreadblockShape,
        WarpShape,
        InstructionShape,
        EnableRowMajorRowMajorLayouts,
        EnableRowMajorColumnMajorLayouts,
        EnableColumnMajorRowMajorLayouts,
        EnableColumnMajorColumnMajorLayouts,
        true>(
        tensor_a,
        tensor_b,
        tensor_c,
        tensor_e,
        alpha,
        beta,
        tensor_d);
  } else {
    spgemm_cutlass_dispatch_layouts<
        ElementInputA,
        ElementInputB,
        ElementOutput,
        ElementAccumulator,
        ThreadblockShape,
        WarpShape,
        InstructionShape,
        EnableRowMajorRowMajorLayouts,
        EnableRowMajorColumnMajorLayouts,
        EnableColumnMajorRowMajorLayouts,
        EnableColumnMajorColumnMajorLayouts,
        false>(
        tensor_a,
        tensor_b,
        tensor_c,
        tensor_e,
        alpha,
        beta,
        tensor_d);
  }
}
#endif

// Perform the multiply-add operation
//     result = alpha * mat1 @ mat2 + beta * input
// using the corresponding CUTLASS sparse GEMM kernel.
// The "mat2" tensor is a dense tensor, while the "mat1" tensor is a
// sparse semi-structured matrix. The "input" tensor is optional; if
// provided, it should be a vector, with the number of elements equal
// to the number of rows of the "mat1" matrix. It is assumed that "mat1"
// and "mat2" are 2D tensors, supplied either in row-major or
// column-major layout (the two tensors may use different layouts, but
// not all combinations of layouts are supported for some datatypes of
// these matrices). The "mat1_meta" argument contains the sparse
// semi-structured metadata.
//
// The CUTLASS sparse GEMM kernel comes with numerous limitations with
// regards to the sizes and alignments of input tensors, their layouts
// and datatypes, and so on; this is the reason for the large number of
// checks throughout this code.
//
// TODO: The "input" tensor has to be a vector, such that it can be
// broadcast to the columns of mat1 * mat2. The case of broadcasting to
// the rows of mat1 * mat2 could also be supported, if the "input"
// tensor is a vector of the corresponding length; likewise for the case
// when the "input" tensor is a matrix of the same size as the
// mat1 * mat2 product. If these updates are made here, then remember to
// update the corresponding bits in the Inductor code that handle meta
// registrations and lowerings of the aten._sparse_semi_structured_mm
// and aten._sparse_semi_structured_addmm operators.
Tensor sparse_semi_structured_mad_op(
    const Tensor& mat1, const Tensor& mat1_meta, const Tensor& mat2,
    const std::optional<Tensor>& input_opt, const Scalar& alpha,
    const Scalar& beta, const std::optional<c10::ScalarType> out_dtype_opt) {
#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
  AT_ERROR(__func__, " : CUTLASS not supported");
  return Tensor{};
#else
  // No need to check that all tensors are on a CUDA device, as this
  // is provided by the dispatch.

  const auto& input = input_opt.value_or(Tensor{});
  const auto out_dtype = out_dtype_opt.value_or(mat2.scalar_type());

  // For now, only CC 8.x devices are supported.
  const auto dprops = at::cuda::getCurrentDeviceProperties();
  const auto is_sm8x = dprops->major == 8;
  TORCH_CHECK(is_sm8x,
              __func__, " : Supported only on GPUs with compute capability "
              "8.x");

  // Validate datatypes of input tensors.
  TORCH_CHECK(mat2.dtype() == at::kChar ||
              mat2.dtype() == at::kHalf ||
              mat2.dtype() == at::kBFloat16 ||
              mat2.dtype() == at::kFloat,
              __func__, " : The mat2 datatype ", mat2.dtype(),
              " is not supported");
  TORCH_CHECK(mat1.dtype() == mat2.dtype(),
              __func__, " : Expected mat1 datatype ", mat2.dtype(),
              ", but got ", mat1.dtype());
  if (input.numel() != 0) {
    TORCH_CHECK(input.dtype() == out_dtype,
                __func__, " : Expected input datatype ", out_dtype,
                ", but got ", input.dtype());
  }

  // Validate layouts of input tensors.
  TORCH_CHECK(mat1.layout() == Layout::Strided,
              __func__, " : Expected mat1 argument to be strided, but got "
              "layout ", mat1.layout());
  TORCH_CHECK(mat1.dim() == 2,
              __func__, " : Expected mat1 argument to be 2D tensor, got ",
              mat1.dim(), " dims");
  const auto strides_a = mat1.strides();
  TORCH_CHECK(strides_a[0] == 1 || strides_a[1] == 1,
              __func__, " : Invalid strides for mat1 argument: row stride = ",
              strides_a[0], ", column stride = ", strides_a[1]);
  TORCH_CHECK(mat2.layout() == Layout::Strided,
              __func__, " : Expected mat2 argument to be "
              "strided, but got layout ", mat2.layout());
  TORCH_CHECK(mat2.dim() == 2,
              __func__, " : Expected mat2 argument to be 2D tensor, got ",
              mat2.dim(), " dims");
  const auto strides_b = mat2.strides();
  TORCH_CHECK(strides_b[0] == 1 || strides_b[1] == 1,
              __func__, " : Invalid strides for mat2 argument: row stride = ",
              strides_b[0], ", column stride = ", strides_b[1]);
  if (input.numel() != 0) {
    TORCH_CHECK(input.layout() == Layout::Strided,
                __func__, " : Expected input argument to be strided, but "
                "got layout ", input.layout());
    TORCH_CHECK(input.dim() == 1,
                __func__, " : Expected input argument to be 1D tensor, "
                "got ", input.dim(), " dims");
  }

  // Validate sizes of input tensors.
  TORCH_CHECK(mat1.size(1) == mat2.size(0) / 2,
              __func__, " : Expected mat1 argument to have ",
              mat2.size(0) / 2, " columns, but got ", mat1.size(1));
  if (input.numel() != 0) {
    TORCH_CHECK(input.size(0) == mat1.size(0),
                __func__, " : Expected input argument to have ",
                mat1.size(0), " elements, but got ", input.size(0));
  }

  // Introduce alias names for arguments, according to the CUTLASS
  // naming conventions.
  const auto& tensor_a = mat1;
  const auto& tensor_b = mat2;
  const auto& tensor_c = input;
  const auto& tensor_e = mat1_meta;

  // Create the output tensor.
  Tensor tensor_d =
      tensor_b.new_empty({tensor_a.size(0), tensor_b.size(1)},
                         at::TensorOptions().dtype(out_dtype));

  // Call the wrapper function for the CUTLASS sparse GEMM, dispatching
  // first on the input datatype and then on the input tensors' layouts.
  // According to the input datatypes and layouts, corresponding template
  // arguments are supplied to instantiate the wrapper function. The tile
  // size template arguments were selected according to CUTLASS profiler
  // results over a number of runs.
  AT_DISPATCH_SWITCH(
      tensor_a.scalar_type(),
      "sparse_semi_structured_mad_op",
      AT_DISPATCH_CASE(
          at::ScalarType::Char,
          [&]() {
            using ElementInputA = int8_t;
            using ElementInputB = int8_t;
            using ElementAccumulator = int32_t;
            using ThreadblockShape =
                cutlass::gemm::GemmShape<128, 128, 128>;
            using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>;
            using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>;
            const auto EnableRowMajorRowMajorLayouts = false;
            const auto EnableRowMajorColumnMajorLayouts = true;
            const auto EnableColumnMajorRowMajorLayouts = false;
            const auto EnableColumnMajorColumnMajorLayouts = false;
            if (out_dtype == at::kInt) {
              using ElementOutput = int32_t;
              spgemm_cutlass_dispatch_layouts_tensor_c<
                  ElementInputA,
                  ElementInputB,
                  ElementOutput,
                  ElementAccumulator,
                  ThreadblockShape,
                  WarpShape,
                  InstructionShape,
                  EnableRowMajorRowMajorLayouts,
                  EnableRowMajorColumnMajorLayouts,
                  EnableColumnMajorRowMajorLayouts,
                  EnableColumnMajorColumnMajorLayouts>(
                  tensor_a,
                  tensor_b,
                  tensor_c,
                  tensor_e,
                  alpha,
                  beta,
                  tensor_d);
            } else if (out_dtype == at::kChar) {
              using ElementOutput = int8_t;
              spgemm_cutlass_dispatch_layouts_tensor_c<
                  ElementInputA,
                  ElementInputB,
                  ElementOutput,
                  ElementAccumulator,
                  ThreadblockShape,
                  WarpShape,
                  InstructionShape,
                  EnableRowMajorRowMajorLayouts,
                  EnableRowMajorColumnMajorLayouts,
                  EnableColumnMajorRowMajorLayouts,
                  EnableColumnMajorColumnMajorLayouts>(
                  tensor_a,
                  tensor_b,
                  tensor_c,
                  tensor_e,
                  alpha,
                  beta,
                  tensor_d);
            }
          })
      AT_DISPATCH_CASE(
          at::ScalarType::Half,
          [&]() {
            using ElementInputA = cutlass::half_t;
            using ElementInputB = cutlass::half_t;
            using ElementOutput = cutlass::half_t;
            using ElementAccumulator = float;
            using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>;
            using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
            using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
            const auto EnableRowMajorRowMajorLayouts = true;
            const auto EnableRowMajorColumnMajorLayouts = true;
            const auto EnableColumnMajorRowMajorLayouts = true;
            const auto EnableColumnMajorColumnMajorLayouts = true;
            spgemm_cutlass_dispatch_layouts_tensor_c<
                ElementInputA,
                ElementInputB,
                ElementOutput,
                ElementAccumulator,
                ThreadblockShape,
                WarpShape,
                InstructionShape,
                EnableRowMajorRowMajorLayouts,
                EnableRowMajorColumnMajorLayouts,
                EnableColumnMajorRowMajorLayouts,
                EnableColumnMajorColumnMajorLayouts>(
                tensor_a,
                tensor_b,
                tensor_c,
                tensor_e,
                alpha,
                beta,
                tensor_d);
          })
      AT_DISPATCH_CASE(
          at::ScalarType::BFloat16,
          [&]() {
            using ElementInputA = cutlass::bfloat16_t;
            using ElementInputB = cutlass::bfloat16_t;
            using ElementOutput = cutlass::bfloat16_t;
            using ElementAccumulator = float;
            using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>;
            using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
            using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
            const auto EnableRowMajorRowMajorLayouts = true;
            const auto EnableRowMajorColumnMajorLayouts = true;
            const auto EnableColumnMajorRowMajorLayouts = true;
            const auto EnableColumnMajorColumnMajorLayouts = true;
            spgemm_cutlass_dispatch_layouts_tensor_c<
                ElementInputA,
                ElementInputB,
                ElementOutput,
                ElementAccumulator,
                ThreadblockShape,
                WarpShape,
                InstructionShape,
                EnableRowMajorRowMajorLayouts,
                EnableRowMajorColumnMajorLayouts,
                EnableColumnMajorRowMajorLayouts,
                EnableColumnMajorColumnMajorLayouts>(
                tensor_a,
                tensor_b,
                tensor_c,
                tensor_e,
                alpha,
                beta,
                tensor_d);
          })
      AT_DISPATCH_CASE(
          at::ScalarType::Float,
          [&]() {
            using ElementInputA = float;
            using ElementInputB = float;
            using ElementOutput = float;
            using ElementAccumulator = float;
            using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>;
            using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>;
            using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
            const auto EnableRowMajorRowMajorLayouts = true;
            const auto EnableRowMajorColumnMajorLayouts = true;
            const auto EnableColumnMajorRowMajorLayouts = true;
            const auto EnableColumnMajorColumnMajorLayouts = true;
            spgemm_cutlass_dispatch_layouts_tensor_c<
                ElementInputA,
                ElementInputB,
                ElementOutput,
                ElementAccumulator,
                ThreadblockShape,
                WarpShape,
                InstructionShape,
                EnableRowMajorRowMajorLayouts,
                EnableRowMajorColumnMajorLayouts,
                EnableColumnMajorRowMajorLayouts,
                EnableColumnMajorColumnMajorLayouts>(
                tensor_a,
                tensor_b,
                tensor_c,
                tensor_e,
                alpha,
                beta,
                tensor_d);
          }));

  return tensor_d;
#endif
}

// Implementation of the aten._sparse_semi_structured_mm operator.
Tensor _sparse_semi_structured_mm(
    const Tensor& mat1, const Tensor& mat1_meta, const Tensor& mat2,
    const std::optional<c10::ScalarType> out_dtype_opt) {
  return sparse_semi_structured_mad_op(mat1, mat1_meta, mat2,
                                       std::optional<Tensor>(), 1, 0,
                                       out_dtype_opt);
}

// Implementation of the aten._sparse_semi_structured_addmm operator.
Tensor _sparse_semi_structured_addmm(
    const Tensor& input, const Tensor& mat1, const Tensor& mat1_meta,
    const Tensor& mat2, const Scalar& alpha, const Scalar& beta,
    const std::optional<c10::ScalarType> out_dtype_opt) {
  return sparse_semi_structured_mad_op(mat1, mat1_meta, mat2, input, alpha,
                                       beta, out_dtype_opt);
}
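
// Illustrative summary of the two operators above (a sketch of the intended
// call pattern, not additional API): given an already-compressed
// semi-structured mat1 and its reordered metadata mat1_meta (e.g. as produced
// by _to_sparse_semi_structured() below), they compute
//
//   _sparse_semi_structured_mm:    mat1 @ mat2
//   _sparse_semi_structured_addmm: alpha * mat1 @ mat2 + beta * input
//
// with the output datatype defaulting to mat2's datatype when out_dtype_opt
// is not provided.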

} // namespace at::native

// Following is just for testing purposes.
namespace at::native {

#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
#else
// Copied from tools/util/include/host_reorder.h in the CUTLASS source
// tree. This is done for simplicity: that file is not under
// include/cutlass in the CUTLASS tree, unlike the other CUTLASS headers
// needed here, so using it directly would require changing the PyTorch
// CMake configuration; furthermore, including it produces build errors
// in PyTorch at the moment.
template <typename Element, typename LayoutDest, typename LayoutSrc>
static void reorder_meta(cutlass::TensorRef<Element, LayoutDest> dest,
                         cutlass::TensorRef<Element, LayoutSrc> src,
                         const int problem_size_m, const int problem_size_k) {
  for (int m = 0; m < problem_size_m; m++) {
    for (int k = 0; k < problem_size_k; k++) {
      // First reorder the rows.
      int group = (sizeof(Element) == 2) ? 32 : 16;
      int interweave = (sizeof(Element) == 2) ? 4 : 2;

      int dest_row = m / group * group + (m % 8) * interweave + (m % group) / 8;
      int dest_col = k;

      // Next swizzle the 2x2 blocks from Z to N.
      if (((dest_row % 2) == 0) && ((dest_col % 2) == 1)) {
        ++dest_row;
        --dest_col;
      } else if (((dest_row % 2) == 1) && ((dest_col % 2) == 0)) {
        --dest_row;
        ++dest_col;
      }

      dest.at({dest_row, dest_col}) = src.at({m, k});
    }
  }
}
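
// Worked example of the index mapping above, for a 16-bit metadata element
// (group == 32, interweave == 4):
//   src(0, 0): dest_row = 0, dest_col = 0; both even, no swizzle -> dest(0, 0)
//   src(8, 0): dest_row = 0 * 32 + 0 * 4 + 1 = 1, dest_col = 0; row odd and
//              column even, so the Z-to-N swizzle maps it to dest(0, 1)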
#endif

std::tuple<Tensor, Tensor>
_to_sparse_semi_structured(const Tensor& dense) {
#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
  AT_ERROR(__func__, " : CUTLASS not supported");
  return std::make_tuple(Tensor{}, Tensor{});
#else
  // Check dimensions of the dense matrix.
  TORCH_CHECK(dense.dim() == 2,
              __func__, " : Expected dense argument to be 2D tensor, got ",
              dense.dim(), " dims");

  // Determine the PyTorch datatype for the metadata matrix.
  auto meta_dtype = at::kChar;
  auto ksparse = 0;
  auto dense_elems_per_meta_elem = 0;
  if (dense.dtype() == at::kChar) {
    meta_dtype = at::kInt;
    ksparse = 4;
    dense_elems_per_meta_elem = 32;
  } else if (dense.dtype() == at::kHalf || dense.dtype() == at::kBFloat16) {
    meta_dtype = at::kShort;
    ksparse = 4;
    dense_elems_per_meta_elem = 16;
  } else if (dense.dtype() == at::kFloat) {
    meta_dtype = at::kShort;
    ksparse = 2;
    dense_elems_per_meta_elem = 8;
  } else {
    AT_ERROR("_to_sparse_semi_structured: Invalid dense argument datatype ",
             dense.dtype(), " encountered");
  }

  const auto dense_nrows = dense.size(0);
  const auto dense_ncols = dense.size(1);

  if (dense_nrows % (meta_dtype == at::kShort ? 32 : 16) != 0) {
    AT_ERROR("_to_sparse_semi_structured: Number of rows of dense matrix must "
             "be divisible by ", (meta_dtype == at::kShort ? 32 : 16),
             ", but it is ", dense_nrows);
  }
  if (dense_ncols % dense_elems_per_meta_elem != 0) {
    AT_ERROR("_to_sparse_semi_structured: Number of columns of dense matrix "
             "must be divisible by ", dense_elems_per_meta_elem, ", but it is ",
             dense_ncols);
  }
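
  // For illustration (not an exhaustive description of valid shapes): a
  // 64 x 64 half-precision dense matrix satisfies both checks above
  // (64 % 32 == 0 and 64 % 16 == 0); its compressed form below ends up
  // 64 x 32, and its metadata 64 x 4 elements of int16.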

  const auto dense_cpu = dense.to("cpu");

  const auto mask_cpu = dense_cpu != at::zeros({1}, dense_cpu.options());

  const auto sparse_cpu =
      dense_cpu.masked_select(mask_cpu).view({dense_nrows, dense_ncols / 2});

  const auto meta_nrows = dense_nrows;
  const auto meta_ncols = dense_ncols / dense_elems_per_meta_elem;
  auto meta_cpu = dense_cpu.new_empty({meta_nrows, meta_ncols},
                                      at::TensorOptions().dtype(meta_dtype));

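  // Encode the metadata on the CPU. Each group of ksparse dense elements
  // contributes one 4-bit nibble that records the positions of the kept
  // elements. For illustration (assuming 2:4 sparsity in 16-bit data): a
  // group with nonzeros at positions 0 and 2, i.e. mask (1, 0, 1, 0), is
  // encoded as the nibble 0b1000 == 8 (position 2 in the high two bits,
  // position 0 in the low two bits), matching the table below.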
  auto* mask_cpu_ptr = mask_cpu.data_ptr<bool>();
  for (auto i = 0; i < meta_nrows; ++i) {
    for (auto j = 0; j < meta_ncols; ++j) {
      uint64_t meta_val = 0;
      for (auto k = 0; k < dense_elems_per_meta_elem / ksparse; ++k, mask_cpu_ptr += ksparse) {
        const auto mask_elems =
            (ksparse == 4) ? std::make_tuple(mask_cpu_ptr[0], mask_cpu_ptr[1],
                                             mask_cpu_ptr[2], mask_cpu_ptr[3])
                           : std::make_tuple(mask_cpu_ptr[0], mask_cpu_ptr[0],
                                             mask_cpu_ptr[1], mask_cpu_ptr[1]);
        auto meta_quadruple = 0;
        if (mask_elems == std::make_tuple(1, 1, 0, 0)) {
          meta_quadruple = 4;   // 0100
        } else if (mask_elems == std::make_tuple(1, 0, 1, 0)) {
          meta_quadruple = 8;   // 1000
        } else if (mask_elems == std::make_tuple(0, 1, 1, 0)) {
          meta_quadruple = 9;   // 1001
        } else if (mask_elems == std::make_tuple(1, 0, 0, 1)) {
          meta_quadruple = 12;  // 1100
        } else if (mask_elems == std::make_tuple(0, 1, 0, 1)) {
          meta_quadruple = 13;  // 1101
        } else if (mask_elems == std::make_tuple(0, 0, 1, 1)) {
          meta_quadruple = 14;  // 1110
        } else {
          AT_ERROR("_to_sparse_semi_structured: dense argument does not match ",
                   (dense.dtype() != at::kFloat) ? "2:4" : "1:2",
                   " sparsity pattern");
        }
        // Widen to uint64_t before shifting, so that large shift amounts do
        // not overflow a plain int.
        meta_val = meta_val | (static_cast<uint64_t>(meta_quadruple) << (4 * k));
      }
      const auto idx = i * meta_ncols + j;
      if (meta_dtype == at::kShort) {
        using MetaElement = int16_t;
        const auto meta_cpu_ptr = meta_cpu.data_ptr<MetaElement>();
        meta_cpu_ptr[idx] = (MetaElement)meta_val;
      } else if (meta_dtype == at::kInt) {
        using MetaElement = int32_t;
        const auto meta_cpu_ptr = meta_cpu.data_ptr<MetaElement>();
        meta_cpu_ptr[idx] = (MetaElement)meta_val;
      }
    }
  }

  auto meta_reordered_cpu = meta_cpu.new_empty({meta_nrows, meta_ncols});
  using MetaLayout = cutlass::layout::RowMajor;
  using MetaReorderedLayout = cutlass::layout::ColumnMajorInterleaved<2>;
  if (meta_dtype == at::kShort) {
    using MetaElement = int16_t;
    auto meta_cpu_ref =
        cutlass::TensorRef<MetaElement, MetaLayout>(
            meta_cpu.data_ptr<MetaElement>(),
            MetaLayout::packed({meta_nrows, meta_ncols}));
    auto meta_reordered_cpu_ref =
        cutlass::TensorRef<MetaElement, MetaReorderedLayout>(
            meta_reordered_cpu.data_ptr<MetaElement>(),
            MetaReorderedLayout::packed({meta_nrows, meta_ncols}));
    reorder_meta(meta_reordered_cpu_ref, meta_cpu_ref, meta_nrows, meta_ncols);
  } else if (meta_dtype == at::kInt) {
    using MetaElement = int32_t;
    auto meta_cpu_ref =
        cutlass::TensorRef<MetaElement, MetaLayout>(
            meta_cpu.data_ptr<MetaElement>(),
            MetaLayout::packed({meta_nrows, meta_ncols}));
    auto meta_reordered_cpu_ref =
        cutlass::TensorRef<MetaElement, MetaReorderedLayout>(
            meta_reordered_cpu.data_ptr<MetaElement>(),
            MetaReorderedLayout::packed({meta_nrows, meta_ncols}));
    reorder_meta(meta_reordered_cpu_ref, meta_cpu_ref, meta_nrows, meta_ncols);
  }

  return std::make_tuple(sparse_cpu.to(dense.device()),
                         meta_reordered_cpu.to(dense.device()));
#endif
}
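
// Illustrative round trip (a sketch only; the actual shape and datatype
// constraints are the ones checked above): given a CUDA tensor `dense` whose
// rows follow the 2:4 (or, for float, 1:2) sparsity pattern, and a
// hypothetical dense matrix `mat2` with dense.size(1) rows, one could do
//
//   auto [compressed, meta] = at::native::_to_sparse_semi_structured(dense);
//   auto result = at::native::_sparse_semi_structured_mm(
//       compressed, meta, mat2, std::nullopt);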

} // namespace at::native