xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredOps.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/ATen.h>
2 #include <ATen/core/Tensor.h>
3 #include <ATen/cuda/CUDAUtils.h>
4 #include <ATen/Dispatch.h>
5 
6 #if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
7 #else
8 #include <cuda_runtime.h>
9 #include <cutlass/cutlass.h>
10 #include <cutlass/layout/layout.h>
11 #include <cutlass/tensor_ref.h>
12 #include <cutlass/gemm/device/gemm_sparse_with_visitor.h>
13 #include <cutlass/epilogue/threadblock/fusion/visitors.hpp>
14 #endif
15 
16 #include <type_traits>
17 #include <tuple>
18 
19 #if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
20 #else
21 #define CUTLASS_STATUS_CHECK(status)                                    \
22   {                                                                     \
23     TORCH_CHECK(status == cutlass::Status::kSuccess,                    \
24                 __func__, " : CUTLASS error: ",                         \
25                 cutlassGetStatusString(status));                        \
26   }
27 #endif
28 
29 namespace at::native {
30 
31 #if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
32 #else
33 // Wrapper function for the CUTLASS sparse GEMM implementation, used
34 // solely to simplify dispatching from the
35 // sparse_semi_structured_mad_op() function below.
36 template <
37     typename ElementInputA,
38     typename ElementInputB,
39     typename ElementOutput,
40     typename ElementAccumulator,
41     typename ThreadblockShape,
42     typename WarpShape,
43     typename InstructionShape,
44     typename LayoutInputA,
45     typename LayoutInputB,
46     bool use_tensor_c>
47 void spgemm_cutlass(
48     const Tensor& tensor_a, const at::IntArrayRef::value_type& tensor_a_stride,
49     const Tensor& tensor_b, const at::IntArrayRef::value_type& tensor_b_stride,
50     const Tensor& tensor_c, const Tensor& tensor_e, const Scalar& alpha,
51     const Scalar& beta, Tensor& tensor_d) {
52     // Fix CUTLASS sparse GEMM template arguments that are not
53     // provided as template arguments of this function, and create an
54     // alias for the particular instantiation of this template.
55     using LayoutOutput = cutlass::layout::RowMajor; // Result of the operation will be provided in row-major format.
56     using MMAOp = cutlass::arch::OpClassTensorOp; // Tensor cores are to be used for maximum performance.
57     using SmArch = cutlass::arch::Sm80; // Only CC 8.x devices are supported at the moment.
58     using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // This choice provides good performance across a wide range of operand sizes.
59     constexpr int NumStages = 3; // This choice provides good performance across a wide range of operand sizes.
60     using Operator = cutlass::arch::OpMultiplyAdd;
61     constexpr int NumEVTEpilogueStages = 1;
62 
63     constexpr int AlignmentInputA = 128 / cutlass::sizeof_bits<ElementInputA>::value;
64     constexpr int AlignmentInputB = 128 / cutlass::sizeof_bits<ElementInputB>::value;
65     constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits<ElementOutput>::value;
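    // (These work out to one 128-bit vectorized access per load/store; e.g.
    // for 16-bit element types, 128 / 16 == 8 elements per access.)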
66 
67     using ElementComputeEpilogue = ElementAccumulator; // Typically slightly slower, but more precise than if ElementOutput were used.
68     constexpr int AlignmentComputeEpilogue = 128 / cutlass::sizeof_bits<ElementComputeEpilogue>::value;
69     using ElementC = ElementOutput;
70     using LayoutC = LayoutOutput;
71     constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
72 
73     using TensorCTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
74         ThreadblockShape,
75         WarpShape,
76         ElementC,
77         AlignmentC,
78         NumEVTEpilogueStages>;
79     using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
80         ThreadblockShape,
81         WarpShape,
82         ElementOutput,
83         AlignmentOutput,
84         NumEVTEpilogueStages>;
85 
86     using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
87 
88     using Alpha =
89         cutlass::epilogue::threadblock::VisitorScalarBroadcast<ElementComputeEpilogue>;
90     using AlphaArguments = typename Alpha::Arguments;
91 
92     using ApplyAlpha = cutlass::epilogue::threadblock::VisitorCompute<
93         cutlass::multiplies, ElementComputeEpilogue, ElementComputeEpilogue,
94         cutlass::FloatRoundStyle::round_to_nearest>;
95     using EVTApplyAlpha = cutlass::epilogue::threadblock::Sm80EVT<
96         ApplyAlpha,
97         Alpha,
98         Accum>;
99 
100     using Beta =
101         cutlass::epilogue::threadblock::VisitorScalarBroadcast<ElementComputeEpilogue>;
102     using BetaArguments = typename Beta::Arguments;
103 
104     using TensorCScalar =
105         cutlass::epilogue::threadblock::VisitorScalarBroadcast<ElementC>;
106     using TensorCTensor =
107         cutlass::epilogue::threadblock::VisitorColBroadcast<
108             TensorCTileThreadMap,
109             ElementC,
110             cute::Stride<cute::_1, cute::_0, int64_t>>;
111     using TensorC = std::conditional_t<use_tensor_c, TensorCTensor, TensorCScalar>;
112     using TensorCArguments = typename TensorC::Arguments;
113 
114     using ApplyBeta = cutlass::epilogue::threadblock::VisitorCompute<
115         cutlass::multiplies, ElementComputeEpilogue, ElementComputeEpilogue,
116         cutlass::FloatRoundStyle::round_to_nearest>;
117     using EVTApplyBeta = cutlass::epilogue::threadblock::Sm80EVT<
118         ApplyBeta,
119         Beta,
120         TensorC>;
121 
122     using ApplySum = cutlass::epilogue::threadblock::VisitorCompute<
123         cutlass::plus, ElementComputeEpilogue, ElementComputeEpilogue,
124         cutlass::FloatRoundStyle::round_to_nearest>;
125     using EVTApplySum = cutlass::epilogue::threadblock::Sm80EVT<
126         ApplySum,
127         EVTApplyAlpha,
128         EVTApplyBeta>;
129 
130     using Output = cutlass::epilogue::threadblock::VisitorAuxStore<
131         OutputTileThreadMap, ElementOutput, cutlass::FloatRoundStyle::round_to_nearest,
132         cute::Stride<int64_t, cute::_1, int64_t>>;
133 
134     using EVTOutput = cutlass::epilogue::threadblock::Sm80EVT<
135         Output,
136         EVTApplySum>;
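    // Taken together, the visitor tree above makes the epilogue compute
    //     D = alpha * accum + beta * C
    // where C is either a scalar zero (when no "input" tensor is given) or a
    // column-broadcast vector, matching the semantics documented for
    // sparse_semi_structured_mad_op() below.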
137 
138     using Gemm = cutlass::gemm::device::SparseGemmWithVisitor<
139         ElementInputA,
140         LayoutInputA,
141         ElementInputB,
142         LayoutInputB,
143         ElementC,
144         LayoutC,
145         ElementAccumulator,
146         MMAOp,
147         SmArch,
148         ThreadblockShape,
149         WarpShape,
150         InstructionShape,
151         EVTOutput,
152         SwizzleThreadBlock,
153         NumStages,
154         AlignmentInputA,
155         AlignmentInputB,
156         Operator,
157         NumEVTEpilogueStages>;
158 
159     // Datatype and layout of metadata matrix are inferred from sparse
160     // GEMM template.
161     using ElementInputE = typename Gemm::ElementE;
162     using LayoutInputE = cutlass::layout::RowMajor;
163     using ReorderedLayoutInputE = typename Gemm::LayoutE;
164     static_assert(
165         std::is_same<ReorderedLayoutInputE,
166                      cutlass::layout::ColumnMajorInterleaved<2>>::value,
167         "The matrix layout used by CUTLASS for reordered sparse GEMM metadata "
168         "has changed, so the code doing conversions from/to the dense matrix "
169         "has to be updated.");
170 
171     constexpr auto kSparse = Gemm::kSparse;
172     constexpr int kElementsPerElementE = Gemm::kElementsPerElementE;
173 
174     // Operand sizes.
175     const int length_m = tensor_a.size(0);
176     const int length_k = tensor_b.size(0);
177     const int length_n = tensor_b.size(1);
178     const auto tensor_e_ncols = length_k / kSparse / kElementsPerElementE;
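    // Illustrative arithmetic only (the actual constants come from the Gemm
    // template): for half/bfloat16 operands, kSparse == 2 and each 16-bit
    // metadata element covers 8 compressed (i.e. 16 dense) elements, so a
    // dense K of 128 yields 128 / 2 / 8 == 8 metadata columns per row.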
179 
180     // Determine PyTorch datatype for the metadata matrix.
181     auto tensor_e_dtype = at::kChar;
182     switch (sizeof(ElementInputE)) {
183     case 2:
184         tensor_e_dtype = at::kShort;
185         break;
186     case 4:
187         tensor_e_dtype = at::kInt;
188         break;
189     default:
190         AT_ERROR(__func__, ": invalid size of meta tensor datatype "
191                  "encountered");
192     }
193     TORCH_CHECK(tensor_e.dtype() == tensor_e_dtype,
194                 __func__, " : Expected meta datatype ", tensor_e_dtype,
195                 ", but got ", tensor_e.dtype());
196 
197     // Prepare arguments for CUTLASS sparse GEMM kernel.
198     cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k);
199     LayoutInputA layout_a(tensor_a_stride);
200     LayoutInputB layout_b(tensor_b_stride);
201     auto tensor_a_device_ref =
202         cutlass::TensorRef<ElementInputA, LayoutInputA>(
203             (ElementInputA*)tensor_a.data_ptr(), layout_a);
204     auto tensor_b_device_ref =
205         cutlass::TensorRef<ElementInputB, LayoutInputB>(
206             (ElementInputB*)tensor_b.data_ptr(), layout_b);
207     auto tensor_e_reordered_device_ref =
208         cutlass::TensorRef<ElementInputE, ReorderedLayoutInputE>(
209             (ElementInputE*)tensor_e.data_ptr(),
210             ReorderedLayoutInputE::packed({length_m, tensor_e_ncols}));
211 
212     AlphaArguments alpha_arguments{
213         [&]() -> AlphaArguments {
214             if constexpr (std::is_same<ElementComputeEpilogue, cutlass::half_t>::value ||
215                           std::is_same<ElementComputeEpilogue, cutlass::bfloat16_t>::value) {
216                 return {ElementComputeEpilogue{alpha.to<float>()}};
217             } else {
218                 return {alpha.to<ElementComputeEpilogue>()};
219             }
220         }()
221     };
222     BetaArguments beta_arguments{
223         [&]() -> BetaArguments {
224             if constexpr (std::is_same<ElementComputeEpilogue, cutlass::half_t>::value ||
225                           std::is_same<ElementComputeEpilogue, cutlass::bfloat16_t>::value) {
226                 return {ElementComputeEpilogue{beta.to<float>()}};
227             } else {
228                 return {beta.to<ElementComputeEpilogue>()};
229             }
230         }()
231     };
232     TensorCArguments tensor_c_arguments{
233         [&]() -> TensorCArguments {
234             if constexpr (use_tensor_c) {
235                 return {(ElementC*)tensor_c.data_ptr(),
236                         ElementC(0),
237                         {cute::_1{}, cute::_0{}, problem_size.m()}};
238             } else {
239                 return {ElementC(0)};
240             }
241         }()
242     };
243     typename Output::Arguments output_arguments{
244         (ElementOutput*)tensor_d.data_ptr(),
245         {problem_size.n(), cute::_1{}, problem_size.mn().product()}
246     };
247     typename EVTOutput::Arguments callback_arguments{
248         {
249             {
250                 alpha_arguments,     // Alpha
251                 {},                  // Accum
252                 {}                   // ApplyAlpha
253             },                       // EVTApplyAlpha
254             {
255                 beta_arguments,      // Beta
256                 tensor_c_arguments,  // TensorC
257                 {}                   // ApplyBeta
258             },                       // EVTApplyBeta
259             {}                       // ApplySum
260         },                           // EVTApplySum
261         output_arguments             // Output
262     };                               // EVTOutput
263 
264     // Create the CUTLASS sparse GEMM kernel arguments object.
265     typename Gemm::Arguments arguments{
266         problem_size,
267         tensor_a_device_ref,
268         tensor_b_device_ref,
269         tensor_e_reordered_device_ref,
270         callback_arguments};
271 
272     cutlass::Status status;
273 
274     // Create CUTLASS sparse GEMM kernel object.
275     Gemm gemm_op;
276 
277     // Verify that sparse GEMM operation with given arguments can be
278     // performed by CUTLASS.
279     status = gemm_op.can_implement(arguments);
280     CUTLASS_STATUS_CHECK(status);
281 
282     // Allocate workspace for CUTLASS sparse GEMM kernel.
283     const auto workspace_size = Gemm::get_workspace_size(arguments);
284     auto workspace = tensor_a.new_empty({(int64_t)workspace_size},
285                                         at::TensorOptions().dtype(at::kByte));
286 
287     // Initialize CUTLASS sparse GEMM object.
288     status = gemm_op.initialize(arguments, workspace.data_ptr(),
289                                 at::cuda::getCurrentCUDAStream());
290     CUTLASS_STATUS_CHECK(status);
291 
292     // Perform sparse GEMM operation.
293     status = gemm_op.run(at::cuda::getCurrentCUDAStream());
294     CUTLASS_STATUS_CHECK(status);
295 
296     C10_CUDA_KERNEL_LAUNCH_CHECK();
297 }
298 
299 // Dispatch according to the combination of the input tensors' layouts.
300 template <
301     typename ElementInputA,
302     typename ElementInputB,
303     typename ElementOutput,
304     typename ElementAccumulator,
305     typename ThreadblockShape,
306     typename WarpShape,
307     typename InstructionShape,
308     bool EnableRowMajorRowMajorLayouts,
309     bool EnableRowMajorColumnMajorLayouts,
310     bool EnableColumnMajorRowMajorLayouts,
311     bool EnableColumnMajorColumnMajorLayouts,
312     bool use_tensor_c>
313 void spgemm_cutlass_dispatch_layouts(
314     const Tensor& tensor_a, const Tensor& tensor_b, const Tensor& tensor_c,
315     const Tensor& tensor_e, const Scalar& alpha, const Scalar& beta,
316     Tensor& tensor_d) {
317     // Determine layouts (row-major or column-major) of input tensors.
318     const auto strides_a = tensor_a.strides();
319     auto tensor_a_row_major = strides_a[1] == 1;
320     auto tensor_a_stride = tensor_a_row_major ? strides_a[0] : strides_a[1];
321     const auto strides_b = tensor_b.strides();
322     auto tensor_b_row_major = strides_b[1] == 1;
323     auto tensor_b_stride = tensor_b_row_major ? strides_b[0] : strides_b[1];
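    // A 2D tensor with strides (ncols, 1) is treated as row-major, and one
    // with strides (1, nrows) as column-major; the caller has already
    // verified that one of the two strides equals 1.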
324 
325     // Perform dispatching.
326     if constexpr (EnableRowMajorRowMajorLayouts) {
327         if (tensor_a_row_major && tensor_b_row_major) {
328             spgemm_cutlass<
329                 ElementInputA,
330                 ElementInputB,
331                 ElementOutput,
332                 ElementAccumulator,
333                 ThreadblockShape,
334                 WarpShape,
335                 InstructionShape,
336                 cutlass::layout::RowMajor,
337                 cutlass::layout::RowMajor,
338                 use_tensor_c>(
339                 tensor_a,
340                 tensor_a_stride,
341                 tensor_b,
342                 tensor_b_stride,
343                 tensor_c,
344                 tensor_e,
345                 alpha,
346                 beta,
347                 tensor_d);
348             return;
349         }
350     }
351     if constexpr (EnableRowMajorColumnMajorLayouts) {
352         if (tensor_a_row_major && !tensor_b_row_major) {
353             spgemm_cutlass<
354                 ElementInputA,
355                 ElementInputB,
356                 ElementOutput,
357                 ElementAccumulator,
358                 ThreadblockShape,
359                 WarpShape,
360                 InstructionShape,
361                 cutlass::layout::RowMajor,
362                 cutlass::layout::ColumnMajor,
363                 use_tensor_c>(
364                 tensor_a,
365                 tensor_a_stride,
366                 tensor_b,
367                 tensor_b_stride,
368                 tensor_c,
369                 tensor_e,
370                 alpha,
371                 beta,
372                 tensor_d);
373             return;
374         }
375     }
376     if constexpr (EnableColumnMajorRowMajorLayouts) {
377         if (!tensor_a_row_major && tensor_b_row_major) {
378             spgemm_cutlass<
379                 ElementInputA,
380                 ElementInputB,
381                 ElementOutput,
382                 ElementAccumulator,
383                 ThreadblockShape,
384                 WarpShape,
385                 InstructionShape,
386                 cutlass::layout::ColumnMajor,
387                 cutlass::layout::RowMajor,
388                 use_tensor_c>(
389                 tensor_a,
390                 tensor_a_stride,
391                 tensor_b,
392                 tensor_b_stride,
393                 tensor_c,
394                 tensor_e,
395                 alpha,
396                 beta,
397                 tensor_d);
398             return;
399         }
400     }
401     if constexpr (EnableColumnMajorColumnMajorLayouts) {
402         if (!tensor_a_row_major && !tensor_b_row_major) {
403             spgemm_cutlass<
404                 ElementInputA,
405                 ElementInputB,
406                 ElementOutput,
407                 ElementAccumulator,
408                 ThreadblockShape,
409                 WarpShape,
410                 InstructionShape,
411                 cutlass::layout::ColumnMajor,
412                 cutlass::layout::ColumnMajor,
413                 use_tensor_c>(
414                 tensor_a,
415                 tensor_a_stride,
416                 tensor_b,
417                 tensor_b_stride,
418                 tensor_c,
419                 tensor_e,
420                 alpha,
421                 beta,
422                 tensor_d);
423             return;
424         }
425     }
426 
427     AT_ERROR(__func__, " : Combination of ",
428              tensor_a_row_major ? "row-major" : "column-major", " and ",
429              tensor_b_row_major ? "row-major" : "column-major",
430              " layouts for input tensors is not supported");
431 }
432 
433 // Dispatch according to whether the tensor_c tensor is provided or not.
434 template <
435     typename ElementInputA,
436     typename ElementInputB,
437     typename ElementOutput,
438     typename ElementAccumulator,
439     typename ThreadblockShape,
440     typename WarpShape,
441     typename InstructionShape,
442     bool EnableRowMajorRowMajorLayouts,
443     bool EnableRowMajorColumnMajorLayouts,
444     bool EnableColumnMajorRowMajorLayouts,
445     bool EnableColumnMajorColumnMajorLayouts>
446 void spgemm_cutlass_dispatch_layouts_tensor_c(
447     const Tensor& tensor_a, const Tensor& tensor_b, const Tensor& tensor_c,
448     const Tensor& tensor_e, const Scalar& alpha, const Scalar& beta,
449     Tensor& tensor_d) {
450     if (tensor_c.numel() > 0) {
451         spgemm_cutlass_dispatch_layouts<
452             ElementInputA,
453             ElementInputB,
454             ElementOutput,
455             ElementAccumulator,
456             ThreadblockShape,
457             WarpShape,
458             InstructionShape,
459             EnableRowMajorRowMajorLayouts,
460             EnableRowMajorColumnMajorLayouts,
461             EnableColumnMajorRowMajorLayouts,
462             EnableColumnMajorColumnMajorLayouts,
463             true>(
464             tensor_a,
465             tensor_b,
466             tensor_c,
467             tensor_e,
468             alpha,
469             beta,
470             tensor_d);
471     } else {
472         spgemm_cutlass_dispatch_layouts<
473             ElementInputA,
474             ElementInputB,
475             ElementOutput,
476             ElementAccumulator,
477             ThreadblockShape,
478             WarpShape,
479             InstructionShape,
480             EnableRowMajorRowMajorLayouts,
481             EnableRowMajorColumnMajorLayouts,
482             EnableColumnMajorRowMajorLayouts,
483             EnableColumnMajorColumnMajorLayouts,
484             false>(
485             tensor_a,
486             tensor_b,
487             tensor_c,
488             tensor_e,
489             alpha,
490             beta,
491             tensor_d);
492     }
493 }
494 #endif
495 
496 // Perform the multiply-add operation, using the corresponding CUTLASS
497 // sparse GEMM kernel, on the given arguments:
498 //     result = alpha * mat1 @ mat2 + beta * input
499 // The "mat2" tensor is a dense tensor, while the "mat1" tensor is a
500 // sparse semi-structured matrix.  The "input" tensor is optional; if
501 // provided, it should be a vector, with the number of elements equal
502 // to the number of rows of the "mat1" matrix.  It is assumed that "mat1"
503 // and "mat2" are 2D tensors, supplied either in row-major or
504 // column-major layouts (different layouts between these two tensors
505 // are OK, but not all combinations of formats are supported for some
506 // datatypes of these matrices).  The "mat1_meta" argument contains
507 // sparse semi-structured metadata.
508 //
509 // There exist numerous limitations of the CUTLASS sparse GEMM kernel,
510 // with regard to the sizes and alignments of input tensors, their
511 // layouts and datatypes, and so on; this is the reason for the large
512 // number of checks throughout the code.
513 //
514 // TODO: The "input" tensor has to be a vector, so that it can be
515 // broadcast to the columns of mat1 * mat2.  The case of broadcasting to
516 // the rows of mat1 * mat2 could also be supported, if the "input" tensor
517 // is a vector of the corresponding length; the same goes for the case
518 // when the "input" tensor is a matrix of the same size as the mat1 * mat2
519 // product.  If these updates are made here, remember to update the
520 // corresponding bits of the Inductor code that handle meta registrations and
521 // lowerings of aten._sparse_semi_structured_mm and
522 // aten._sparse_semi_structured_addmm operators.
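//
// A minimal usage sketch, in terms of the at:: wrappers generated for these
// native functions (illustrative only: the shapes are made up, and the
// compressed mat1/mat1_meta pair would normally come from a 2:4 pruning step
// such as the _to_sparse_semi_structured() helper at the end of this file):
//
//     const auto dense =
//         at::randn({32, 64}, at::dtype(at::kHalf).device(at::kCUDA));
//     // ... zero out two elements in every group of four along dim 1 ...
//     auto [mat1, mat1_meta] = at::_to_sparse_semi_structured(dense);
//     const auto mat2 = at::randn({64, 128}, dense.options());
//     // out = 1 * (mat1 @ mat2) + 0 * input, with no input given
//     const auto out =
//         at::_sparse_semi_structured_mm(mat1, mat1_meta, mat2, std::nullopt);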
523 Tensor sparse_semi_structured_mad_op(
524       const Tensor& mat1, const Tensor& mat1_meta, const Tensor& mat2,
525       const std::optional<Tensor>& input_opt, const Scalar& alpha,
526       const Scalar& beta, const std::optional<c10::ScalarType> out_dtype_opt) {
527 #if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
528     AT_ERROR(__func__, " : CUTLASS not supported");
529     return Tensor{};
530 #else
531     // No need to check that all tensors are on CUDA device, as this
532     // is provided by dispatch.
533 
534     const auto& input = input_opt.value_or(Tensor{});
535     const auto out_dtype = out_dtype_opt.value_or(mat2.scalar_type());
536 
537     // For now, only CC 8.x devices are supported.
538     const auto dprops = at::cuda::getCurrentDeviceProperties();
539     const auto is_sm8x = dprops->major == 8;
540     TORCH_CHECK(is_sm8x,
541                 __func__, " : Supported only on GPUs with compute capability "
542                 "8.x");
543 
544     // Validate datatypes of input tensors.
545     TORCH_CHECK(mat2.dtype() == at::kChar ||
546                 mat2.dtype() == at::kHalf ||
547                 mat2.dtype() == at::kBFloat16 ||
548                 mat2.dtype() == at::kFloat,
549                 __func__, " : The mat2 datatype ", mat2.dtype(),
550                 " is not supported");
551     TORCH_CHECK(mat1.dtype() == mat2.dtype(),
552                 __func__, " : Expected mat1 datatype ", mat2.dtype(),
553                 ", but got ", mat1.dtype());
554     if (input.numel() != 0) {
555         TORCH_CHECK(input.dtype() == out_dtype,
556                     __func__, " : Expected input datatype ", out_dtype,
557                     ", but got ", input.dtype());
558     }
559 
560     // Validate layouts of input tensors.
561     TORCH_CHECK(mat1.layout() == Layout::Strided,
562                 __func__, " : Expected mat1 argument to be strided, but got "
563                 "layout ", mat1.layout());
564     TORCH_CHECK(mat1.dim() == 2,
565                 __func__, " : Expected mat1 argument to be 2D tensor, got ",
566                 mat1.dim(), " dims");
567     const auto strides_a = mat1.strides();
568     TORCH_CHECK(strides_a[0] == 1 || strides_a[1] == 1,
569                 __func__, " : Invalid strides for mat1 argument: row stride = ",
570                 strides_a[0], ", column stride = ", strides_a[1]);
571     TORCH_CHECK(mat2.layout() == Layout::Strided,
572                 __func__, " : Expected mat2 argument to be "
573                 "strided, but got layout ", mat2.layout());
574     TORCH_CHECK(mat2.dim() == 2,
575                 __func__, " : Expected mat2 argument to be 2D tensor, got ",
576                 mat2.dim(), " dims");
577     const auto strides_b = mat2.strides();
578     TORCH_CHECK(strides_b[0] == 1 || strides_b[1] == 1,
579                 __func__, " : Invalid strides for mat2 argument: row stride = ",
580                 strides_b[0], ", column stride = ", strides_b[1]);
581     if (input.numel() != 0) {
582         TORCH_CHECK(input.layout() == Layout::Strided,
583                     __func__, " : Expected input argument to be strided, but "
584                     "got layout ", input.layout());
585         TORCH_CHECK(input.dim() == 1,
586                     __func__, " : Expected input argument to be 1D tensor, "
587                     "got ", input.dim(), " dims");
588     }
589 
590     // Validate sizes of input tensors.
591     TORCH_CHECK(mat1.size(1) == mat2.size(0) / 2,
592                 __func__, " : Expected mat1 argument to have ",
593                 mat2.size(0) / 2, " columns, but got ", mat1.size(1));
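    // ("mat1" is stored in compressed form: with 2:4 sparsity (1:2 for
    // float) it keeps half of the elements along the inner dimension, so it
    // has mat2.size(0) / 2 columns while the product is still computed over
    // the full K = mat2.size(0).)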
594     if (input.numel() != 0) {
595         TORCH_CHECK(input.size(0) == mat1.size(0),
596                     __func__, " : Expected input argument to have ",
597                     mat1.size(0), " elements, but got ", input.size(0));
598     }
599 
600     // Introduce alias names for arguments, according to the CUTLASS
601     // naming conventions.
602     const auto& tensor_a = mat1;
603     const auto& tensor_b = mat2;
604     const auto& tensor_c = input;
605     const auto& tensor_e = mat1_meta;
606 
607     // Create output tensor.
608     Tensor tensor_d =
609         tensor_b.new_empty({tensor_a.size(0), tensor_b.size(1)},
610                            at::TensorOptions().dtype(out_dtype));
611 
612     // Call the wrapper function for the CUTLASS sparse GEMM, dispatching on
613     // the input datatype, and then on the input tensors' layouts.
614     // According to the input tensors' datatypes and layouts, the
615     // corresponding template arguments are supplied for instantiating
616     // the wrapper function.  The tile size template arguments are
617     // selected according to CUTLASS profiler results, collected over a
618     // number of runs.
619     AT_DISPATCH_SWITCH(
620         tensor_a.scalar_type(),
621         "sparse_semi_structured_mad_op",
622         AT_DISPATCH_CASE(
623             at::ScalarType::Char,
624             [&]() {
625                 using ElementInputA = int8_t;
626                 using ElementInputB = int8_t;
627                 using ElementAccumulator = int32_t;
628                 using ThreadblockShape =
629                     cutlass::gemm::GemmShape<128, 128, 128>;
630                 using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>;
631                 using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>;
632                 const auto EnableRowMajorRowMajorLayouts = false;
633                 const auto EnableRowMajorColumnMajorLayouts = true;
634                 const auto EnableColumnMajorRowMajorLayouts = false;
635                 const auto EnableColumnMajorColumnMajorLayouts = false;
636                 if (out_dtype == at::kInt) {
637                   using ElementOutput = int32_t;
638                   spgemm_cutlass_dispatch_layouts_tensor_c<
639                       ElementInputA,
640                       ElementInputB,
641                       ElementOutput,
642                       ElementAccumulator,
643                       ThreadblockShape,
644                       WarpShape,
645                       InstructionShape,
646                       EnableRowMajorRowMajorLayouts,
647                       EnableRowMajorColumnMajorLayouts,
648                       EnableColumnMajorRowMajorLayouts,
649                       EnableColumnMajorColumnMajorLayouts>(
650                       tensor_a,
651                       tensor_b,
652                       tensor_c,
653                       tensor_e,
654                       alpha,
655                       beta,
656                       tensor_d);
657                 } else if (out_dtype == at::kChar) {
658                   using ElementOutput = int8_t;
659                   spgemm_cutlass_dispatch_layouts_tensor_c<
660                       ElementInputA,
661                       ElementInputB,
662                       ElementOutput,
663                       ElementAccumulator,
664                       ThreadblockShape,
665                       WarpShape,
666                       InstructionShape,
667                       EnableRowMajorRowMajorLayouts,
668                       EnableRowMajorColumnMajorLayouts,
669                       EnableColumnMajorRowMajorLayouts,
670                       EnableColumnMajorColumnMajorLayouts>(
671                       tensor_a,
672                       tensor_b,
673                       tensor_c,
674                       tensor_e,
675                       alpha,
676                       beta,
677                       tensor_d);
678                 }
679             })
680         AT_DISPATCH_CASE(
681             at::ScalarType::Half,
682             [&]() {
683                 using ElementInputA = cutlass::half_t;
684                 using ElementInputB = cutlass::half_t;
685                 using ElementOutput = cutlass::half_t;
686                 using ElementAccumulator = float;
687                 using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>;
688                 using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
689                 using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
690                 const auto EnableRowMajorRowMajorLayouts = true;
691                 const auto EnableRowMajorColumnMajorLayouts = true;
692                 const auto EnableColumnMajorRowMajorLayouts = true;
693                 const auto EnableColumnMajorColumnMajorLayouts = true;
694                 spgemm_cutlass_dispatch_layouts_tensor_c<
695                     ElementInputA,
696                     ElementInputB,
697                     ElementOutput,
698                     ElementAccumulator,
699                     ThreadblockShape,
700                     WarpShape,
701                     InstructionShape,
702                     EnableRowMajorRowMajorLayouts,
703                     EnableRowMajorColumnMajorLayouts,
704                     EnableColumnMajorRowMajorLayouts,
705                     EnableColumnMajorColumnMajorLayouts>(
706                     tensor_a,
707                     tensor_b,
708                     tensor_c,
709                     tensor_e,
710                     alpha,
711                     beta,
712                     tensor_d);
713             })
714             AT_DISPATCH_CASE(
715             at::ScalarType::BFloat16,
716             [&]() {
717                 using ElementInputA = cutlass::bfloat16_t;
718                 using ElementInputB = cutlass::bfloat16_t;
719                 using ElementOutput = cutlass::bfloat16_t;
720                 using ElementAccumulator = float;
721                 using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>;
722                 using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
723                 using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
724                 const auto EnableRowMajorRowMajorLayouts = true;
725                 const auto EnableRowMajorColumnMajorLayouts = true;
726                 const auto EnableColumnMajorRowMajorLayouts = true;
727                 const auto EnableColumnMajorColumnMajorLayouts = true;
728                 spgemm_cutlass_dispatch_layouts_tensor_c<
729                     ElementInputA,
730                     ElementInputB,
731                     ElementOutput,
732                     ElementAccumulator,
733                     ThreadblockShape,
734                     WarpShape,
735                     InstructionShape,
736                     EnableRowMajorRowMajorLayouts,
737                     EnableRowMajorColumnMajorLayouts,
738                     EnableColumnMajorRowMajorLayouts,
739                     EnableColumnMajorColumnMajorLayouts>(
740                     tensor_a,
741                     tensor_b,
742                     tensor_c,
743                     tensor_e,
744                     alpha,
745                     beta,
746                     tensor_d);
747             })
748             AT_DISPATCH_CASE(
749             at::ScalarType::Float,
750             [&]() {
751                 using ElementInputA = float;
752                 using ElementInputB = float;
753                 using ElementOutput = float;
754                 using ElementAccumulator = float;
755                 using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>;
756                 using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>;
757                 using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
758                 const auto EnableRowMajorRowMajorLayouts = true;
759                 const auto EnableRowMajorColumnMajorLayouts = true;
760                 const auto EnableColumnMajorRowMajorLayouts = true;
761                 const auto EnableColumnMajorColumnMajorLayouts = true;
762                 spgemm_cutlass_dispatch_layouts_tensor_c<
763                     ElementInputA,
764                     ElementInputB,
765                     ElementOutput,
766                     ElementAccumulator,
767                     ThreadblockShape,
768                     WarpShape,
769                     InstructionShape,
770                     EnableRowMajorRowMajorLayouts,
771                     EnableRowMajorColumnMajorLayouts,
772                     EnableColumnMajorRowMajorLayouts,
773                     EnableColumnMajorColumnMajorLayouts>(
774                     tensor_a,
775                     tensor_b,
776                     tensor_c,
777                     tensor_e,
778                     alpha,
779                     beta,
780                     tensor_d);
781             }));
782 
783     return tensor_d;
784 #endif
785 }
786 
787 // Implementation of aten._sparse_semi_structured_mm operator.
788 Tensor _sparse_semi_structured_mm(
789       const Tensor& mat1, const Tensor& mat1_meta, const Tensor& mat2,
790       const std::optional<c10::ScalarType> out_dtype_opt) {
791     return sparse_semi_structured_mad_op(mat1, mat1_meta, mat2,
792                                          std::optional<Tensor>(), 1, 0,
793                                          out_dtype_opt);
794 }
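// (The constants 1 and 0 above are the alpha/beta values: with no "input"
// tensor supplied, the operation reduces to result = mat1 @ mat2, produced
// in out_dtype when one is given.)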
795 
796 // Implementation of aten._sparse_semi_structured_addmm operator.
797 Tensor _sparse_semi_structured_addmm(
798       const Tensor& input, const Tensor& mat1, const Tensor& mat1_meta,
799       const Tensor& mat2, const Scalar& alpha, const Scalar& beta,
800       const std::optional<c10::ScalarType> out_dtype_opt) {
801     return sparse_semi_structured_mad_op(mat1, mat1_meta, mat2, input, alpha,
802                                          beta, out_dtype_opt);
803 }
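// A hypothetical call through the corresponding at:: wrapper (illustrative
// only; "input" must be a vector with mat1.size(0) elements, and element i
// is added, scaled by beta, to every entry of row i of the product):
//
//     const auto out = at::_sparse_semi_structured_addmm(
//         input, mat1, mat1_meta, mat2, /*alpha=*/1.0, /*beta=*/1.0,
//         std::nullopt);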
804 
805 } // namespace at::native
806 
807 // The following is just for testing purposes.
808 namespace at::native {
809 
810 #if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
811 #else
812 // Copied from tools/util/include/host_reorder.h, from CUTLASS source
813 // tree.  This is for simplicity - namely, that file is not under
814 // include/cutlass in this tree, as the other needed CUTLASS include
815 // files are, so including it would require changing the PyTorch CMake
816 // configuration; furthermore, including that file produces build errors
817 // in PyTorch at the moment.
818 template <typename Element, typename LayoutDest, typename LayoutSrc>
819 static void reorder_meta(cutlass::TensorRef<Element, LayoutDest> dest,
820                          cutlass::TensorRef<Element, LayoutSrc> src,
821                          const int problem_size_m, const int problem_size_k) {
822   for (int m = 0; m < problem_size_m; m++) {
823     for (int k = 0; k < problem_size_k; k++) {
824       // First reorder the rows.
825       int group = (sizeof(Element) == 2) ? 32 : 16;
826       int interweave = (sizeof(Element) == 2) ? 4 : 2;
827 
828       int dest_row = m / group * group + (m % 8) * interweave + (m % group) / 8;
829       int dest_col = k;
830 
831       // Next swizzle the 2x2 blocks from Z to N.
832       if (((dest_row % 2) == 0) && ((dest_col % 2) == 1)) {
833         ++dest_row;
834         --dest_col;
835       } else if (((dest_row % 2) == 1) && ((dest_col % 2) == 0)) {
836         --dest_row;
837         ++dest_col;
838       }
839 
840       dest.at({dest_row, dest_col}) = src.at({m, k});
841     }
842   }
843 }
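// Worked example of the row permutation above for 16-bit metadata elements
// (group == 32, interweave == 4): source rows 0..7 map to destination rows
// 0, 4, 8, ..., 28, source row 8 maps to row 1, row 9 to row 5, and so on,
// after which the 2x2 Z-to-N swizzle swaps the two off-diagonal elements
// within each 2x2 block.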
844 #endif
845 
846 std::tuple<Tensor, Tensor>
847 _to_sparse_semi_structured(const Tensor& dense) {
848 #if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
849   AT_ERROR(__func__, " : CUTLASS not supported");
850   return std::make_tuple(Tensor{}, Tensor{});
851 #else
852   // Check dimensions of the dense matrix.
853   TORCH_CHECK(dense.dim() == 2,
854               __func__, " : Expected dense argument to be 2D tensor, got ",
855               dense.dim(), " dims");
856 
857   // Determine PyTorch datatype for the metadata matrix.
858   auto meta_dtype = at::kChar;
859   auto ksparse = 0;
860   auto dense_elems_per_meta_elem = 0;
861   if (dense.dtype() == at::kChar) {
862     meta_dtype = at::kInt;
863     ksparse = 4;
864     dense_elems_per_meta_elem = 32;
865   } else if (dense.dtype() == at::kHalf || dense.dtype() == at::kBFloat16) {
866     meta_dtype = at::kShort;
867     ksparse = 4;
868     dense_elems_per_meta_elem = 16;
869   } else if (dense.dtype() == at::kFloat) {
870     meta_dtype = at::kShort;
871     ksparse = 2;
872     dense_elems_per_meta_elem = 8;
873   } else {
874     AT_ERROR("_to_sparse_semi_structured: Invalid dense argument datatype ",
875              dense.dtype(), " encountered");
876   }
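  // To summarize the mapping above: int8 inputs use 32-bit metadata elements
  // covering 32 dense elements each, half/bfloat16 use 16-bit metadata
  // covering 16 dense elements, and float uses 16-bit metadata covering 8
  // dense elements; the first two follow a 2:4 sparsity pattern (groups of
  // 4), float a 1:2 pattern (groups of 2).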
877 
878   const auto dense_nrows = dense.size(0);
879   const auto dense_ncols = dense.size(1);
880 
881   if (dense_nrows % (meta_dtype == at::kShort ? 32 : 16) != 0) {
882     AT_ERROR("_to_sparse_semi_structured: Number of rows of dense matrix must "
883              "be divisible by ", (meta_dtype == at::kShort ? 32 : 16),
884              ", but it is ", dense_nrows);
885   }
886   if (dense_ncols % dense_elems_per_meta_elem != 0) {
887     AT_ERROR("_to_sparse_semi_structured: Number of columns of dense matrix "
888              "must be divisible by ", dense_elems_per_meta_elem, ", but it is ",
889              dense_ncols);
890   }
891 
892   const auto dense_cpu = dense.to("cpu");
893 
894   const auto mask_cpu = dense_cpu != at::zeros({1}, dense_cpu.options());
895 
896   const auto sparse_cpu =
897     dense_cpu.masked_select(mask_cpu).view({dense_nrows, dense_ncols / 2});
898 
899   const auto meta_nrows = dense_nrows;
900   const auto meta_ncols = dense_ncols / dense_elems_per_meta_elem;
901   auto meta_cpu = dense_cpu.new_empty({meta_nrows, meta_ncols},
902                                       at::TensorOptions().dtype(meta_dtype));
903 
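  // Each group of ksparse dense elements contributes one 4-bit metadata
  // nibble holding the two 2-bit column indices of the kept elements (first
  // kept index in the low bits).  For example, mask (1, 1, 0, 0) keeps
  // columns 0 and 1, so the nibble is 0b0100 == 4; mask (0, 0, 1, 1) keeps
  // columns 2 and 3, giving 0b1110 == 14.  dense_elems_per_meta_elem /
  // ksparse such nibbles are packed into each metadata element.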
904   auto* mask_cpu_ptr = mask_cpu.data_ptr<bool>();
905   for (auto i = 0; i < meta_nrows; ++i) {
906     for (auto j = 0; j < meta_ncols; ++j) {
907       uint64_t meta_val = 0;
908       for (auto k = 0; k < dense_elems_per_meta_elem / ksparse; ++k, mask_cpu_ptr += ksparse) {
909         const auto mask_elems =
910           (ksparse == 4) ? std::make_tuple(mask_cpu_ptr[0], mask_cpu_ptr[1],
911                                            mask_cpu_ptr[2], mask_cpu_ptr[3])
912                          : std::make_tuple(mask_cpu_ptr[0], mask_cpu_ptr[0],
913                                            mask_cpu_ptr[1], mask_cpu_ptr[1]);
914         auto meta_quadruple = 0;
915         if (mask_elems == std::make_tuple(1, 1, 0, 0)) {
916           meta_quadruple = 4; // 0100
917         } else if (mask_elems == std::make_tuple(1, 0, 1, 0)) {
918           meta_quadruple = 8; // 1000
919         } else if (mask_elems == std::make_tuple(0, 1, 1, 0)) {
920           meta_quadruple = 9; // 1001
921         } else if (mask_elems == std::make_tuple(1, 0, 0, 1)) {
922           meta_quadruple = 12; // 1100
923         } else if (mask_elems == std::make_tuple(0, 1, 0, 1)) {
924           meta_quadruple = 13; // 1101
925         } else if (mask_elems == std::make_tuple(0, 0, 1, 1)) {
926           meta_quadruple = 14; // 1110
927         } else {
928           AT_ERROR("_to_sparse_semi_structured: dense argument does not match ",
929                    (dense.dtype() != at::kFloat) ? "2:4" : "1:2",
930                    " sparsity pattern");
931         }
932         meta_val = meta_val | (meta_quadruple << (4 * k));
933       }
934       const auto idx = i * meta_ncols + j;
935       if (meta_dtype == at::kShort) {
936         using MetaElement = int16_t;
937         const auto meta_cpu_ptr = meta_cpu.data_ptr<MetaElement>();
938         meta_cpu_ptr[idx] = (MetaElement)meta_val;
939       } else if (meta_dtype == at::kInt) {
940         using MetaElement = int32_t;
941         const auto meta_cpu_ptr = meta_cpu.data_ptr<MetaElement>();
942         meta_cpu_ptr[idx] = (MetaElement)meta_val;
943       }
944     }
945   }
946 
947   auto meta_reordered_cpu = meta_cpu.new_empty({meta_nrows, meta_ncols});
948   using MetaLayout = cutlass::layout::RowMajor;
949   using MetaReorderedLayout = cutlass::layout::ColumnMajorInterleaved<2>;
950   if (meta_dtype == at::kShort) {
951     using MetaElement = int16_t;
952     auto meta_cpu_ref =
953       cutlass::TensorRef<MetaElement, MetaLayout>(
954           meta_cpu.data_ptr<MetaElement>(),
955           MetaLayout::packed({meta_nrows, meta_ncols}));
956     auto meta_reordered_cpu_ref =
957       cutlass::TensorRef<MetaElement, MetaReorderedLayout>(
958           meta_reordered_cpu.data_ptr<MetaElement>(),
959           MetaReorderedLayout::packed({meta_nrows, meta_ncols}));
960     reorder_meta(meta_reordered_cpu_ref, meta_cpu_ref, meta_nrows, meta_ncols);
961   } else if (meta_dtype == at::kInt) {
962     using MetaElement = int32_t;
963     auto meta_cpu_ref =
964       cutlass::TensorRef<MetaElement, MetaLayout>(
965           meta_cpu.data_ptr<MetaElement>(),
966           MetaLayout::packed({meta_nrows, meta_ncols}));
967     auto meta_reordered_cpu_ref =
968       cutlass::TensorRef<MetaElement, MetaReorderedLayout>(
969           meta_reordered_cpu.data_ptr<MetaElement>(),
970           MetaReorderedLayout::packed({meta_nrows, meta_ncols}));
971     reorder_meta(meta_reordered_cpu_ref, meta_cpu_ref, meta_nrows, meta_ncols);
972   }
973 
974   return std::make_tuple(sparse_cpu.to(dense.device()),
975                          meta_reordered_cpu.to(dense.device()));
976 #endif
977 }
978 
979 }  // namespace at::native
980