/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Inspired by "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h".
    Loads tiles of GEMM operands from a RowMajor shared-memory layout into
    registers for use by the A100 TensorCores.

    The differences with "mma_tensor_op_tile_access_iterator.h" are:
    (1) We use "ldmatrix" to load tiles, rather than manual loads (slightly
        faster).
    (2) We support transposing the operand (e.g. reading `A.transpose()` when
        the shared memory holds `A`).

    This is only implemented for specific shapes.
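    A minimal usage sketch (illustrative only; `smem_ref`, `lane_id` and the
    warp-level MMA that consumes the fragment are assumptions, not part of this
    header; the instruction shape is the 16x16 A tile of the m16n8k16 MMA):

    \code
    using WarpIterator = cutlass::gemm::warp::WarpIteratorFromSmem<
        cutlass::gemm::Operand::kA,
        cutlass::half_t,
        cutlass::MatrixShape<16, 16>>;

    // smem_ref: TensorRef<cutlass::half_t, cutlass::layout::RowMajor>
    WarpIterator iterator(smem_ref, lane_id);
    typename WarpIterator::Fragment frag;

    CUTLASS_PRAGMA_UNROLL
    for (int k = 0; k < WarpIterator::kIterations; ++k) {
      iterator.load(frag); // issues ldmatrix.x4 for the current k-group
      ++iterator;          // advances along K; wraps to the next tile after the last group
      // ... hand `frag` to the warp-level MMA here ...
    }
    \endcode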
*/
#pragma once

#include <cutlass/gemm/gemm.h>

////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace warp {

template <
    /// Operand identity
    Operand Operand_,
    /// Data type of A elements
    typename Element_,
    /// Shape of the instruction tile (concept: MatrixShape)
    typename InstructionShape_,
    /// Whether to load the transposed operand
    bool kTranspose = false>
class WarpIteratorFromSmem {
 public:
  /// Shape of tile to load (concept: MatrixShape)
  using Shape = cutlass::MatrixShape<32, 32>;

  /// Operand tag
  static Operand const kOperand = Operand_;
  static_assert(
      kOperand == Operand::kA,
      "No support for OperandB at the moment");

  /// Basic check
  static_assert(
      kOperand == Operand::kA || kOperand == Operand::kB,
      "WarpIteratorFromSmem may only be instantiated for A or B operands to warp-level Mma.");

  /// Element type
  using Element = Element_;
  static_assert(sizeof_bits<Element>::value == 16, "Only supported for half");

  /// Layout of source tile
  using Layout = cutlass::layout::RowMajor;

  /// Shape of one matrix product operation (concept: MatrixShape)
  using InstructionShape = InstructionShape_;
  static_assert(InstructionShape::kRow == 16, "Only supports 16x8x8 / 16x8x16");
  static_assert(
      InstructionShape::kColumn == 8 || InstructionShape::kColumn == 16,
      "Only supports 16x8x8 / 16x8x16");

  /// Delta between *MMA operations (in units of *MMA operations, concept:
  /// MatrixShape)
  static int const kOpDelta = 1;

  /// Number of participating threads
  static int const kThreads = 32;

  /// TensorRef type for loading element from a tensor
  using TensorRef = TensorRef<Element, Layout>;

  /// Index type
  using Index = typename TensorRef::Index;

  /// Long Index type
  using LongIndex = typename TensorRef::LongIndex;

  /// Coordinate for an element in the tensor
  using TensorCoord = typename TensorRef::TensorCoord;

  /// Number of elements accessed per Shared Memory load
  static int const kElementsPerAccess =
      (sizeof_bits<Element>::value >= 32 ? 1
                                         : 32 / sizeof_bits<Element>::value);

  using InstructionCount = MatrixShape<
      Shape::kRow / InstructionShape::kRow,
      Shape::kColumn / InstructionShape::kColumn>;

  static int const kIterations = (kOperand == Operand::kA)
      ? InstructionCount::kColumn
      : InstructionCount::kRow;

 public:
  //
  // Derived quantities
  //

  /// Fragment object holding a thread's part of a tile
  using Fragment = Array<
      Element,
      (kOperand == Operand::kA)
          ? (Shape::kRow * InstructionShape::kColumn / kThreads)
          : (Shape::kColumn * InstructionShape::kRow / kThreads)>;

  /// Memory access type
  // using AccessType = AlignedArray<Element, kElementsPerAccess>;
  using AccessType = Array<unsigned, 4>;

  static int constexpr kWarpShapeDivisibleInner =
      (kOperand == Operand::kA ? InstructionShape::kColumn
                               : InstructionShape::kRow);
  static int constexpr kAccessesInner =
      (kWarpShapeDivisibleInner / kElementsPerAccess) / 4;
  // Number of 32-bit tiles to load per `ldmatrix`
  static int const kTilesPerInstruction = InstructionShape::kRow / 8;
  static_assert(kTilesPerInstruction == 2, "Only supports 16x8x16 and 16x8x8");
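
  // Worked example of the quantities above, assuming Operand::kA, a 16-bit
  // Element and InstructionShape = MatrixShape<16, 16> (the A tile of the
  // m16n8k16 MMA); the numbers follow directly from the definitions:
  //   kElementsPerAccess       = 32 / 16 = 2 elements per 32-bit access
  //   Fragment                 = Array<Element, 32 * 16 / 32 = 16>
  //   kWarpShapeDivisibleInner = InstructionShape::kColumn = 16
  //   kAccessesInner           = (16 / 2) / 4 = 2
  //   kTilesPerInstruction     = 16 / 8 = 2
  //   InstructionCount         = MatrixShape<32 / 16, 32 / 16> = 2 x 2
  //   kIterations              = InstructionCount::kColumn = 2
  // so each call to load() issues (2 * 2 * 2) / 4 = 2 ldmatrix.x4 instructions.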

 private:
  /// Underlying tensor reference
  TensorRef ref_;

  /// Origin
  MatrixCoord origin_;

  /// Iterations in a tile
  int iterations_;

 public:
  /// Constructor from TensorRef
  CUTLASS_HOST_DEVICE
  WarpIteratorFromSmem(TensorRef const& ref, int lane_id)
      : WarpIteratorFromSmem(ref, {Shape::kRow, Shape::kColumn}, lane_id) {}

  CUTLASS_HOST_DEVICE
  WarpIteratorFromSmem(TensorRef const& ref, TensorCoord extent, int lane_id)
      : ref_(ref), iterations_(0) {
    // See also:
    // https://docs.nvidia.com/cuda/archive/11.7.1/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-1688
    // 16x8x8: kAccessesInner = 1 (1 ldmatrix.x4)
    // 16x8x16: kAccessesInner = 2 (2 ldmatrix.x4)
    // For `ldmatrix.x4`, each group of 8 lanes provides the row addresses of
    // one 8x8 tile: `ldsm_vec_num` selects which of the four tiles this lane
    // addresses, and `lane_id % 8` the row within it.
    int ldsm_vec_num = (lane_id >> 3);
    if (kOperand == Operand::kA) {
      origin_ = MatrixCoord(lane_id % 8, 0);
      static_assert(
          InstructionCount::kRow * kTilesPerInstruction == 4,
          "can't use ldmatrix.x4");
      int access_m_idx = ldsm_vec_num % kTilesPerInstruction;
      int inner_idx = (ldsm_vec_num / kTilesPerInstruction) % kAccessesInner;
      int inst_m_idx = ldsm_vec_num / (kTilesPerInstruction * kAccessesInner);
      MatrixCoord offset(
          access_m_idx * 8 + inst_m_idx * InstructionShape::kRow,
          inner_idx * 4 * kElementsPerAccess);
      if (kTranspose) {
        offset = MatrixCoord(offset.column(), offset.row());
      }
      origin_ += offset;
    } else {
      // XXX: This is not tested or used
      origin_ = MatrixCoord(0, lane_id % 8);
      static_assert(InstructionCount::kColumn * kAccessesInner == 4, "");
      CUTLASS_PRAGMA_UNROLL
      for (int inst_n_idx = 0; inst_n_idx < InstructionCount::kColumn;
           ++inst_n_idx) {
        CUTLASS_PRAGMA_UNROLL
        for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) {
          int access_idx = inner_idx + kAccessesInner * inst_n_idx;

          MatrixCoord offset(
              inner_idx * 4 * kElementsPerAccess, inst_n_idx * 8);

          if (access_idx == ldsm_vec_num) {
            if (kTranspose) {
              offset = MatrixCoord(offset.column(), offset.row());
            }
            origin_ += offset;
          }
        }
      }
    }

    ref_.add_coord_offset(origin_);
  }

  /// Advances an iterator along logical dimensions of matrix in units of whole
  /// tiles
  CUTLASS_HOST_DEVICE
  WarpIteratorFromSmem& add_tile_offset(TensorCoord const& tile_offset) {
    TensorCoord coord_offset(
        tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn);
    if (kTranspose) {
      coord_offset = TensorCoord{coord_offset.column(), coord_offset.row()};
    }
    origin_ += coord_offset;

    ref_.add_coord_offset(coord_offset);

    return *this;
  }

  /// Advances the iterator along the advance dimension
  CUTLASS_DEVICE
  void advance() {
    if (kOperand == Operand::kA) {
      add_tile_offset({0, 1});
    } else {
      add_tile_offset({1, 0});
    }

    iterations_ = 0;
  }

  /// Increases iterations within a tile
  CUTLASS_HOST_DEVICE
  WarpIteratorFromSmem& operator++() {
    iterations_++;

    if (iterations_ >= kIterations)
      advance();

    return *this;
  }

  /// Loads a fragment from memory at the location pointed to by the iterator.
  CUTLASS_DEVICE
  void load(Fragment& frag) const {
    AccessType* access_ptr = reinterpret_cast<AccessType*>(&frag);
    using LoadLayout = typename platform::
        conditional<kTranspose, layout::ColumnMajor, layout::RowMajor>::type;

    CUTLASS_PRAGMA_UNROLL
    for (int access_m_idx = 0; access_m_idx <
         (InstructionCount::kRow * kTilesPerInstruction * kAccessesInner) / 4;
         ++access_m_idx) {
      MatrixCoord offset;
      if (kOperand == Operand::kA) {
        offset = MatrixCoord(
            access_m_idx * 16, iterations_ * InstructionShape::kColumn);
      } else {
        offset = MatrixCoord(iterations_ * InstructionShape::kRow, 0);
      }
      if (kTranspose) {
        offset = MatrixCoord(offset.column(), offset.row());
      }
      cutlass::arch::ldsm<LoadLayout, 4>(
          access_ptr[access_m_idx], ref_.data() + ref_.offset(offset));
    }
  }
};

////////////////////////////////////////////////////////////////////////////////

} // namespace warp
} // namespace gemm
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////