1 #pragma once 2 3 #include <ATen/native/cuda/cutlass_extensions/arch/mma.h> 4 #include <ATen/native/cuda/cutlass_extensions/interleaved_numeric_conversion.h> 5 6 namespace cutlass { 7 namespace gemm { 8 namespace threadblock { 9 //////////////////////////////////////////////////////////////////////////////// 10 11 // We need to distinguish here, since we want volta support. It is too much effort 12 // to write shared memory iterators that are probably needed for volta to function 13 // properly. As a result, we allow converters both after the LDG (for volta) and after 14 // the LDS for Turing+. 15 template< 16 /// Iterator for B matrix in global memory 17 typename IteratorB, 18 /// Warp level Mma 19 typename MmaOperator, 20 /// Math operation perform by warp level operator 21 typename MathOperator> 22 struct SetConverters { 23 }; 24 25 // Dequantize after LDG, so set transforms accordingly 26 template< 27 /// Iterator for B matrix in global memory 28 typename IteratorB, 29 /// Mma Policy 30 typename MmaOperator> 31 struct SetConverters<IteratorB, MmaOperator, arch::OpMultiplyAdd> { 32 using TransformAfterLDG = 33 FastInterleavedAndBiasedNumericArrayConverter<typename MmaOperator::ArchMmaOperator::ElementB, 34 typename IteratorB::Element, 35 IteratorB::Fragment::kElements>; 36 37 using TransformAfterLDS = NumericArrayConverter<typename MmaOperator::ArchMmaOperator::ElementB, 38 typename MmaOperator::ArchMmaOperator::ElementB, 39 MmaOperator::FragmentB::kElements>; 40 }; 41 42 // Dequantize after LDS, so set transforms accordingly 43 44 template< 45 /// Iterator for B matrix in global memory 46 typename IteratorB, 47 /// Mma Policy 48 typename MmaOperator> 49 struct SetConverters<IteratorB, MmaOperator, arch::OpMultiplyAddDequantizeInterleavedBToA> { 50 using TransformAfterLDG = 51 NumericArrayConverter<typename IteratorB::Element, typename IteratorB::Element, IteratorB::Fragment::kElements>; 52 53 using TransformAfterLDS = 54 FastInterleavedAndBiasedNumericArrayConverter<typename MmaOperator::ArchMmaOperator::ElementB, 55 typename TransformAfterLDG::result_type::Element, 56 MmaOperator::FragmentB::kElements>; 57 }; 58 59 //////////////////////////////////////////////////////////////////////////////// 60 61 template< 62 /// Element type for A matrix operand 63 typename ElementA_, 64 /// Layout type for A matrix operand 65 typename LayoutA_, 66 /// Access granularity of A matrix in units of elements 67 int kAlignmentA, 68 /// Element type for B matrix operand 69 typename ElementB_, 70 /// Layout type for B matrix operand 71 typename LayoutB_, 72 /// Access granularity of B matrix in units of elements 73 int kAlignmentB, 74 /// Element type for the input scale 75 typename ElementScale_, 76 /// Layout for the scale operand 77 typename LayoutScale_, 78 /// Access granularity of Scales in unit of elements 79 int kAlignmentScale, 80 /// Element type for internal accumulation 81 typename ElementAccumulator_, 82 /// Layout type for C and D matrix operands 83 typename LayoutC_, 84 /// Operator class tag 85 typename OperatorClass_, 86 /// Tag indicating architecture to tune for 87 typename ArchTag_, 88 /// Threadblock-level tile size (concept: GemmShape) 89 typename ThreadblockShape_, 90 /// Warp-level tile size (concept: GemmShape) 91 typename WarpShape_, 92 /// Instruction-level tile size (concept: GemmShape) 93 typename InstructionShape_, 94 /// Number of stages used in the pipelined mainloop 95 int Stages, 96 /// Operation performed by GEMM 97 typename Operator_, 98 /// Use zfill or predicate for out-of-bound cp.async 99 SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, 100 /// 101 typename Enable = void> 102 struct DqMma; 103 104 } // namespace threadblock 105 } // namespace gemm 106 } // namespace cutlass 107