#pragma once

#include <ATen/native/cuda/cutlass_extensions/arch/mma.h>
#include <ATen/native/cuda/cutlass_extensions/interleaved_numeric_conversion.h>

namespace cutlass {
namespace gemm {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////

// We need to distinguish here because we want Volta support. It is too much effort
// to write the shared memory iterators that would probably be needed for Volta to
// function properly. As a result, we allow converters both after the LDG (for Volta)
// and after the LDS (for Turing+). See the illustrative sketch after the
// specializations below.
template<
    /// Iterator for B matrix in global memory
    typename IteratorB,
    /// Warp level Mma
    typename MmaOperator,
    /// Math operation performed by the warp-level operator
    typename MathOperator>
struct SetConverters {
};

// Dequantize after LDG, so set transforms accordingly
template<
    /// Iterator for B matrix in global memory
    typename IteratorB,
    /// Warp level Mma
    typename MmaOperator>
struct SetConverters<IteratorB, MmaOperator, arch::OpMultiplyAdd> {
    using TransformAfterLDG =
        FastInterleavedAndBiasedNumericArrayConverter<typename MmaOperator::ArchMmaOperator::ElementB,
                                                      typename IteratorB::Element,
                                                      IteratorB::Fragment::kElements>;

    using TransformAfterLDS = NumericArrayConverter<typename MmaOperator::ArchMmaOperator::ElementB,
                                                    typename MmaOperator::ArchMmaOperator::ElementB,
                                                    MmaOperator::FragmentB::kElements>;
};

// Dequantize after LDS, so set transforms accordingly
template<
    /// Iterator for B matrix in global memory
    typename IteratorB,
    /// Warp level Mma
    typename MmaOperator>
struct SetConverters<IteratorB, MmaOperator, arch::OpMultiplyAddDequantizeInterleavedBToA> {
    using TransformAfterLDG =
        NumericArrayConverter<typename IteratorB::Element, typename IteratorB::Element, IteratorB::Fragment::kElements>;

    using TransformAfterLDS =
        FastInterleavedAndBiasedNumericArrayConverter<typename MmaOperator::ArchMmaOperator::ElementB,
                                                      typename TransformAfterLDG::result_type::Element,
                                                      MmaOperator::FragmentB::kElements>;
};
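
// Illustrative sketch (not part of this header): given a concrete global-memory
// iterator `IteratorB` and warp-level operator `WarpMma` (hypothetical names), a
// mainloop would pick its converter pair via SetConverters and apply the two
// transforms around the shared-memory staging of B:
//
//   using Converters = SetConverters<IteratorB, WarpMma,
//                                    arch::OpMultiplyAddDequantizeInterleavedBToA>;
//
//   typename Converters::TransformAfterLDG ldg_converter;  // pass-through here
//   typename Converters::TransformAfterLDS lds_converter;  // integer -> fp dequantizing conversion
//
//   // After loading a fragment from global memory:
//   auto staged = ldg_converter(gmem_fragment);   // written to shared memory as-is
//   // After loading a fragment from shared memory, just before the warp-level MMA:
//   auto operand_b = lds_converter(smem_fragment);
//
// With arch::OpMultiplyAdd (the Volta path) the roles are reversed: the LDG
// converter performs the dequantizing conversion and the LDS converter is an
// identity NumericArrayConverter.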

////////////////////////////////////////////////////////////////////////////////

template<
    /// Element type for A matrix operand
    typename ElementA_,
    /// Layout type for A matrix operand
    typename LayoutA_,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Element type for B matrix operand
    typename ElementB_,
    /// Layout type for B matrix operand
    typename LayoutB_,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for the input scale
    typename ElementScale_,
    /// Layout for the scale operand
    typename LayoutScale_,
    /// Access granularity of scales in units of elements
    int kAlignmentScale,
    /// Element type for internal accumulation
    typename ElementAccumulator_,
    /// Layout type for C and D matrix operands
    typename LayoutC_,
    /// Operator class tag
    typename OperatorClass_,
    /// Tag indicating architecture to tune for
    typename ArchTag_,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape_,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape_,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape_,
    /// Number of stages used in the pipelined mainloop
    int Stages,
    /// Operation performed by GEMM
    typename Operator_,
    /// Use zfill or predicate for out-of-bound cp.async
    SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
    ///
    typename Enable = void>
struct DqMma;
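
// DqMma is only declared here; partial specializations (selected on architecture,
// stage count, and operator tag) define the actual threadblock-scoped mainloop.
// A purely illustrative instantiation, assuming fp16 activations with int8
// weights and a per-column fp16 scale (every parameter choice below is
// hypothetical and may not match an existing specialization):
//
//   using Mma = DqMma<half_t, layout::RowMajor, 8,          // A
//                     uint8_t, layout::ColumnMajor, 16,     // B (quantized)
//                     half_t, layout::RowMajor, 8,          // scales
//                     float, layout::RowMajor,              // accumulator, C/D layout
//                     arch::OpClassTensorOp, arch::Sm80,
//                     GemmShape<128, 128, 64>,              // threadblock tile
//                     GemmShape<64, 64, 64>,                // warp tile
//                     GemmShape<16, 8, 16>,                 // instruction tile
//                     3,                                    // stages
//                     arch::OpMultiplyAddDequantizeInterleavedBToA>;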

}  // namespace threadblock
}  // namespace gemm
}  // namespace cutlass