1 /*************************************************************************************************** 2 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights 3 *reserved. SPDX-License-Identifier: BSD-3-Clause 4 * 5 * Redistribution and use in source and binary forms, with or without 6 * modification, are permitted provided that the following conditions are met: 7 * 8 * 1. Redistributions of source code must retain the above copyright notice, 9 *this list of conditions and the following disclaimer. 10 * 11 * 2. Redistributions in binary form must reproduce the above copyright notice, 12 * this list of conditions and the following disclaimer in the documentation 13 * and/or other materials provided with the distribution. 14 * 15 * 3. Neither the name of the copyright holder nor the names of its 16 * contributors may be used to endorse or promote products derived from 17 * this software without specific prior written permission. 18 * 19 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 22 *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 23 *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 24 *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 25 *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 26 *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 27 *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 28 *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 29 *POSSIBILITY OF SUCH DAMAGE. 30 * 31 **************************************************************************************************/ 32 /*! \file 33 \brief Template for a double-buffered threadblock-scoped GEMM kernel. 34 */ 35 36 #pragma once 37 38 #include <cutlass/aligned_buffer.h> 39 #include <cutlass/arch/memory.h> 40 #include <cutlass/array.h> 41 #include <cutlass/cutlass.h> 42 #include <cutlass/gemm/gemm.h> 43 #include <cutlass/gemm/threadblock/mma_base.h> 44 #include <cutlass/matrix_shape.h> 45 #include <cutlass/numeric_types.h> 46 47 //////////////////////////////////////////////////////////////////////////////// 48 49 namespace cutlass { 50 namespace gemm { 51 namespace threadblock { 52 53 //////////////////////////////////////////////////////////////////////////////// 54 55 /// Structure to compute the matrix product targeting CUDA cores and SIMT math 56 /// instructions. 57 template < 58 /// Size of the Gemm problem - concept: gemm::GemmShape<> 59 typename Shape_, 60 /// Policy describing tuning details (concept: MmaPolicy) 61 typename Policy_, 62 /// Number of stages, 63 int Stages, 64 /// Used for partial specialization 65 typename Enable = bool> 66 class CustomMmaBase { 67 public: 68 ///< Size of the Gemm problem - concept: gemm::GemmShape<> 69 using Shape = Shape_; 70 71 ///< Policy describing tuning details 72 using Policy = Policy_; 73 74 // 75 // Dependent types 76 // 77 78 /// Warp-level Mma 79 using Operator = typename Policy::Operator; 80 81 /// Shape describing the overall GEMM computed from shared memory 82 /// by each warp. 83 using WarpGemm = typename Policy::Operator::Shape; 84 85 /// Shape describing the number of warps filling the CTA 86 using WarpCount = GemmShape< 87 Shape::kM / WarpGemm::kM, 88 Shape::kN / WarpGemm::kN, 89 Shape::kK / WarpGemm::kK>; 90 91 /// Number of warp-level GEMM oeprations 92 static int const kWarpGemmIterations = 93 (WarpGemm::kK / Operator::Policy::MmaShape::kK); 94 95 /// Number of stages 96 static int const kStages = Stages; 97 98 // 99 // Nested structs 100 // 101 102 /// Shared storage object needed by threadblock-scoped GEMM 103 template <typename Element, typename OperandShape, typename OperandLayout> 104 struct OperandSharedStorage { 105 AlignedBuffer<Element, OperandShape::kCount> buffer; 106 using TensorRef = TensorRef<Element, OperandLayout>; 107 108 CUTLASS_DEVICE LayoutOperandSharedStorage109 static OperandLayout Layout() { 110 return OperandLayout::packed({OperandShape::kRow, OperandShape::kColumn}); 111 } 112 113 /// Returns a TensorRef to the operand 114 CUTLASS_HOST_DEVICE refOperandSharedStorage115 TensorRef ref() { 116 return TensorRef{buffer.data(), Layout()}; 117 } 118 }; 119 120 /// Shape of the A matrix operand in shared memory 121 using ShapeA = MatrixShape< 122 Shape::kM + Policy::SmemPaddingA::kRow, 123 Shape::kK * kStages + Policy::SmemPaddingA::kColumn>; 124 125 /// Shape of the B matrix operand in shared memory 126 using ShapeB = MatrixShape< 127 Shape::kK * kStages + Policy::SmemPaddingB::kRow, 128 Shape::kN + Policy::SmemPaddingB::kColumn>; 129 130 using SharedStorageA = OperandSharedStorage< 131 typename Operator::ElementA, 132 ShapeA, 133 typename Operator::LayoutA>; 134 using SharedStorageB = OperandSharedStorage< 135 typename Operator::ElementB, 136 ShapeB, 137 typename Operator::LayoutB>; 138 using TensorRefA = typename SharedStorageA::TensorRef; 139 using TensorRefB = typename SharedStorageB::TensorRef; 140 141 struct SharedStorage { 142 /// Buffer for A operand 143 SharedStorageA operand_A; 144 145 /// Buffer for B operand 146 SharedStorageB operand_B; 147 }; 148 149 protected: 150 // 151 // Data members 152 // 153 154 /// Iterator to load a warp-scoped tile of A operand from shared memory 155 typename Operator::IteratorA warp_tile_iterator_A_; 156 157 /// Iterator to load a warp-scoped tile of B operand from shared memory 158 typename Operator::IteratorB warp_tile_iterator_B_; 159 160 public: 161 /// Construct from tensor references 162 CUTLASS_DEVICE CustomMmaBase(SharedStorageA & shared_storageA,SharedStorageB & shared_storageB,int thread_idx,int warp_idx,int lane_idx)163 CustomMmaBase( 164 ///< Shared storage needed for internal use by threadblock-scoped GEMM 165 SharedStorageA& shared_storageA, 166 SharedStorageB& shared_storageB, 167 ///< ID within the threadblock 168 int thread_idx, 169 ///< ID of warp 170 int warp_idx, 171 ///< ID of each thread within a warp 172 int lane_idx) 173 : warp_tile_iterator_A_(shared_storageA.ref(), lane_idx), 174 warp_tile_iterator_B_(shared_storageB.ref(), lane_idx) {} 175 }; 176 177 ///////////////////////////////////////////////////////////////////////////////////////////////// 178 179 } // namespace threadblock 180 } // namespace gemm 181 } // namespace cutlass 182 183 ///////////////////////////////////////////////////////////////////////////////////////////////// 184