1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 /*! \file 9 \brief Cutlass provides helper template functions to figure out the right 10 data structures to instantiate to run a GEMM with various parameters (see 11 `cutlass/gemm/threadblock/default_mma.h`). However, due to template 12 instantiation priority rules, it will only create an MmaMultiStage with 13 kStages=3 (otherwise creates an MmePipelined - which is not compatible with 14 FastF32). kStages=3 uses too much shared memory and we want to use kStages=2, 15 so we just copy-pasted some code from `default_mma.h` and 16 `default_mma_core.h` files and wrapped this template to allow our use case. 17 18 This is really only for the FastF32 case - aka using TensorCores with fp32. 19 */ 20 21 #pragma once 22 23 #include <cutlass/gemm/threadblock/default_mma.h> 24 #include <cutlass/gemm/threadblock/default_mma_core_simt.h> 25 #include <cutlass/gemm/threadblock/default_mma_core_sm70.h> 26 #include <cutlass/gemm/threadblock/default_mma_core_sm75.h> 27 #include <cutlass/gemm/threadblock/default_mma_core_sm80.h> 28 29 namespace cutlass { 30 namespace gemm { 31 namespace threadblock { 32 33 template < 34 /// Element type for A matrix operand 35 typename ElementA, 36 /// Layout type for A matrix operand 37 typename LayoutA, 38 /// Access granularity of A matrix in units of elements 39 int kAlignmentA, 40 /// Element type for B matrix operand 41 typename ElementB, 42 /// Layout type for B matrix operand 43 typename LayoutB, 44 /// Access granularity of B matrix in units of elements 45 int kAlignmentB, 46 /// Element type for internal accumulation 47 typename ElementAccumulator, 48 /// Layout type for C and D matrix operand 49 typename LayoutC, 50 /// Operator class tag 51 typename OperatorClass, 52 /// Tag indicating architecture to tune for 53 typename ArchTag, 54 /// Threadblock-level tile size (concept: GemmShape) 55 typename ThreadblockShape, 56 /// Warp-level tile size (concept: GemmShape) 57 typename WarpShape, 58 /// Instruction-level tile size (concept: GemmShape) 59 typename InstructionShape, 60 /// Number of stages used in the pipelined mainloop 61 int Stages, 62 /// Operation performed by GEMM 63 typename Operator, 64 typename Enable_ = void> 65 struct FindDefaultMma { 66 static constexpr bool AccumulatorsInRowMajor = false; 67 static constexpr SharedMemoryClearOption SharedMemoryClear = 68 SharedMemoryClearOption::kNone; 69 using DefaultMma = cutlass::gemm::threadblock::DefaultMma< 70 ElementA, 71 LayoutA, 72 kAlignmentA, 73 ElementB, 74 LayoutB, 75 kAlignmentB, 76 ElementAccumulator, 77 LayoutC, 78 OperatorClass, 79 ArchTag, 80 ThreadblockShape, 81 WarpShape, 82 InstructionShape, 83 Stages, 84 Operator, 85 AccumulatorsInRowMajor, 86 SharedMemoryClear>; 87 }; 88 89 /// Specialization for sm80 / FastF32 / multistage with kStages=2 90 template < 91 typename ElementA_, 92 /// Layout type for A matrix operand 93 typename LayoutA_, 94 /// Access granularity of A matrix in units of elements 95 int kAlignmentA, 96 typename ElementB_, 97 /// Layout type for B matrix operand 98 typename LayoutB_, 99 /// Access granularity of B matrix in units of elements 100 int kAlignmentB, 101 typename ElementAccumulator, 102 /// Threadblock-level tile size (concept: GemmShape) 103 typename ThreadblockShape, 104 /// Warp-level tile size (concept: GemmShape) 105 typename WarpShape, 106 /// Instruction-level tile size (concept: GemmShape) 107 typename InstructionShape, 108 int kStages, 109 typename Operator> 110 struct FindDefaultMma< 111 ElementA_, 112 LayoutA_, 113 kAlignmentA, 114 ElementB_, 115 LayoutB_, 116 kAlignmentB, 117 ElementAccumulator, 118 layout::RowMajor, 119 arch::OpClassTensorOp, 120 arch::Sm80, 121 ThreadblockShape, 122 WarpShape, 123 InstructionShape, 124 kStages, 125 Operator, 126 typename cutlass::platform::enable_if<(kAlignmentA > 1)>::type> { 127 using LayoutC = layout::RowMajor; 128 using OperatorClass = arch::OpClassTensorOp; 129 using ArchTag = arch::Sm80; 130 131 using DefaultMma_ = cutlass::gemm::threadblock::DefaultMma< 132 ElementA_, 133 LayoutA_, 134 kAlignmentA, 135 ElementB_, 136 LayoutB_, 137 kAlignmentB, 138 ElementAccumulator, 139 LayoutC, 140 OperatorClass, 141 ArchTag, 142 ThreadblockShape, 143 WarpShape, 144 InstructionShape, 145 3, 146 Operator>; 147 struct DefaultMma : DefaultMma_ { 148 using MmaCore_ = typename DefaultMma_::MmaCore; 149 // Define the threadblock-scoped multistage matrix multiply 150 using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage< 151 typename MmaCore_::Shape, 152 typename DefaultMma_::IteratorA, 153 typename MmaCore_::SmemIteratorA, 154 MmaCore_::kCacheOpA, 155 typename DefaultMma_::IteratorB, 156 typename MmaCore_::SmemIteratorB, 157 MmaCore_::kCacheOpB, 158 ElementAccumulator, 159 LayoutC, 160 typename MmaCore_::MmaPolicy, 161 kStages>; 162 }; 163 }; 164 165 } // namespace threadblock 166 } // namespace gemm 167 } // namespace cutlass 168