/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Inspired by "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h".
    Loads tiles of GEMM operands from a RowMajor shared-memory layout into
    registers for use by A100 TensorCores.

    The differences with "mma_tensor_op_tile_access_iterator.h" are that:
    (1) We use "ldmatrix" to load tiles, rather than manual loads (slightly
    faster).
    (2) We support transposing the operand (e.g. reading `A.transpose()` when
    the shared memory holds `A`).

    This is only implemented for specific shapes.
*/
#pragma once

#include <cutlass/gemm/gemm.h>

////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace warp {
template <
    /// Operand identity
    Operand Operand_,
    /// Data type of A elements
    typename Element_,
    typename InstructionShape_,
    bool kTranspose = false>
class WarpIteratorFromSmem {
 public:
  /// Shape of tile to load (concept: MatrixShape)
  using Shape = cutlass::MatrixShape<32, 32>;

  /// Operand tag
  static Operand const kOperand = Operand_;
  static_assert(
      kOperand == Operand::kA,
      "No support for OperandB at the moment");

  /// Basic check
  static_assert(
      kOperand == Operand::kA || kOperand == Operand::kB,
      "WarpIteratorFromSmem may only be instantiated for A or B operands to warp-level Mma.");

  /// Element type
  using Element = Element_;
  static_assert(sizeof_bits<Element>::value == 16, "Only supported for half");

  /// Layout of source tile
  using Layout = cutlass::layout::RowMajor;

  /// Shape of one matrix product operation (concept: MatrixShape)
  using InstructionShape = InstructionShape_;
  static_assert(InstructionShape::kRow == 16, "Only supports 16x8x8 / 16x8x16");
  static_assert(
      InstructionShape::kColumn == 8 || InstructionShape::kColumn == 16,
      "Only supports 16x8x8 / 16x8x16");

  /// Delta between *MMA operations (in units of *MMA operations, concept:
  /// MatrixShape)
  static int const kOpDelta = 1;

  /// Number of participating threads
  static int const kThreads = 32;

  /// TensorRef type for loading element from a tensor
  using TensorRef = TensorRef<Element, Layout>;

  /// Index type
  using Index = typename TensorRef::Index;

  /// Long Index type
  using LongIndex = typename TensorRef::LongIndex;

  /// Coordinate for an element in the tensor
  using TensorCoord = typename TensorRef::TensorCoord;

  /// Number of elements accessed per Shared Memory load
  static int const kElementsPerAccess =
      (sizeof_bits<Element>::value >= 32 ? 1
                                         : 32 / sizeof_bits<Element>::value);
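  // With 16-bit elements (the only supported case), each 32-bit access packs
  // kElementsPerAccess = 32 / 16 = 2 elements.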

  using InstructionCount = MatrixShape<
      Shape::kRow / InstructionShape::kRow,
      Shape::kColumn / InstructionShape::kColumn>;

  static int const kIterations = (kOperand == Operand::kA)
      ? InstructionCount::kColumn
      : InstructionCount::kRow;
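  // For the 32x32 tile, InstructionCount is 2x4 with a 16x8 instruction
  // (kIterations = 4 for operand A) and 2x2 with a 16x16 instruction
  // (kIterations = 2).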

 public:
  //
  // Derived quantities
  //

  /// Fragment object holding a thread's part of a tile
  using Fragment = Array<
      Element,
      (kOperand == Operand::kA)
          ? (Shape::kRow* InstructionShape::kColumn / kThreads)
          : (Shape::kColumn* InstructionShape::kRow / kThreads)>;
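  // For operand A, this is 32 * 8 / 32 = 8 elements per thread and iteration
  // with the 16x8x8 instruction, or 32 * 16 / 32 = 16 elements with 16x8x16.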

  /// Memory access type
  // using AccessType = AlignedArray<Element, kElementsPerAccess>;
  using AccessType = Array<unsigned, 4>;
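  // One AccessType corresponds to the four 32-bit registers that a single
  // ldmatrix.x4 instruction writes per thread.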

  static int constexpr kWarpShapeDivisibleInner =
      (kOperand == Operand::kA ? InstructionShape::kColumn
                               : InstructionShape::kRow);
  static int constexpr kAccessesInner =
      (kWarpShapeDivisibleInner / kElementsPerAccess) / 4;
  // Number of 8x8 tiles along the M dimension per mma instruction
  static int const kTilesPerInstruction = InstructionShape::kRow / 8;
  static_assert(kTilesPerInstruction == 2, "Only supports 16x8x16 and 16x8x8");
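  // For half precision these evaluate to kAccessesInner = (8 / 2) / 4 = 1
  // (16x8x8) or (16 / 2) / 4 = 2 (16x8x16), and kTilesPerInstruction = 16 / 8 = 2.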

 private:
  /// Underlying tensor reference
  TensorRef ref_;

  /// Origin
  MatrixCoord origin_;

  /// Iterations in a tile
  int iterations_;

 public:
  /// Constructor from TensorRef
  CUTLASS_HOST_DEVICE
  WarpIteratorFromSmem(TensorRef const& ref, int lane_id)
      : WarpIteratorFromSmem(ref, {Shape::kRow, Shape::kColumn}, lane_id) {}
  CUTLASS_HOST_DEVICE
  WarpIteratorFromSmem(TensorRef const& ref, TensorCoord extent, int lane_id)
      : ref_(ref), iterations_(0) {
    // See also:
    // https://docs.nvidia.com/cuda/archive/11.7.1/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-1688
    // 16x8x8: kAccessesInner = 1 (1 ldmatrix.x4)
    // 16x8x16: kAccessesInner = 2 (2 ldmatrix.x4)
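    // ldmatrix.x4 loads four 8x8 tiles, and each group of 8 lanes supplies
    // the shared-memory row addresses for one of them. ldsm_vec_num
    // (= lane_id / 8) therefore selects which of the four tiles this lane
    // addresses; for operand A it is decomposed below into an
    // (access_m_idx, inner_idx, inst_m_idx) position within the 32-row tile.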
    int ldsm_vec_num = (lane_id >> 3);
    if (kOperand == Operand::kA) {
      origin_ = MatrixCoord(lane_id % 8, 0);
      static_assert(
          InstructionCount::kRow * kTilesPerInstruction == 4,
          "can't use ldmatrix.x4");
      int access_m_idx = ldsm_vec_num % kTilesPerInstruction;
      int inner_idx = (ldsm_vec_num / kTilesPerInstruction) % kAccessesInner;
      int inst_m_idx = ldsm_vec_num / (kTilesPerInstruction * kAccessesInner);
      MatrixCoord offset(
          access_m_idx * 8 + inst_m_idx * InstructionShape::kRow,
          inner_idx * 4 * kElementsPerAccess);
      if (kTranspose) {
        offset = MatrixCoord(offset.column(), offset.row());
      }
      origin_ += offset;
    } else {
      // XXX: This is not tested or used
      origin_ = MatrixCoord(0, lane_id % 8);
      static_assert(InstructionCount::kColumn * kAccessesInner == 4, "");
      CUTLASS_PRAGMA_UNROLL
      for (int inst_n_idx = 0; inst_n_idx < InstructionCount::kColumn;
           ++inst_n_idx) {
        CUTLASS_PRAGMA_UNROLL
        for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) {
          int access_idx = inner_idx + kAccessesInner * inst_n_idx;

          MatrixCoord offset(
              inner_idx * 4 * kElementsPerAccess, inst_n_idx * 8);

          if (access_idx == ldsm_vec_num) {
            if (kTranspose) {
              offset = MatrixCoord(offset.column(), offset.row());
            }
            origin_ += offset;
          }
        }
      }
    }

    ref_.add_coord_offset(origin_);
  }

  /// Advances an iterator along logical dimensions of matrix in units of whole
  /// tiles
  CUTLASS_HOST_DEVICE
  WarpIteratorFromSmem& add_tile_offset(TensorCoord const& tile_offset) {
    TensorCoord coord_offset(
        tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn);
    if (kTranspose) {
      coord_offset = TensorCoord{coord_offset.column(), coord_offset.row()};
    }
    origin_ += coord_offset;

    ref_.add_coord_offset(coord_offset);

    return *this;
  }

  /// Advances the iterator along the advance dimension
  CUTLASS_DEVICE
  void advance() {
    if (kOperand == Operand::kA) {
      add_tile_offset({0, 1});
    } else {
      add_tile_offset({1, 0});
    }

    iterations_ = 0;
  }

  /// Increments the iteration count within a tile, advancing to the next tile
  /// once all iterations of the current tile have been consumed
  CUTLASS_HOST_DEVICE
  WarpIteratorFromSmem& operator++() {
    iterations_++;

    if (iterations_ >= kIterations)
      advance();

    return *this;
  }

  /// Loads a fragment from memory at the location pointed to by the iterator.
  CUTLASS_DEVICE
  void load(Fragment& frag) const {
    AccessType* access_ptr = reinterpret_cast<AccessType*>(&frag);
    using LoadLayout = typename platform::
        conditional<kTranspose, layout::ColumnMajor, layout::RowMajor>::type;

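    // Each loop iteration issues one ldmatrix.x4 (the transposed variant when
    // kTranspose is set). For operand A that is a single iteration covering a
    // 32x8 block with the 16x8x8 shape, or two iterations of 16x16 blocks
    // with 16x8x16.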
    CUTLASS_PRAGMA_UNROLL
    for (int access_m_idx = 0; access_m_idx <
         (InstructionCount::kRow * kTilesPerInstruction * kAccessesInner) / 4;
         ++access_m_idx) {
      MatrixCoord offset;
      if (kOperand == Operand::kA) {
        offset = MatrixCoord(
            access_m_idx * 16, iterations_ * InstructionShape::kColumn);
      } else {
        offset = MatrixCoord(iterations_ * InstructionShape::kRow, 0);
      }
      if (kTranspose) {
        offset = MatrixCoord(offset.column(), offset.row());
      }
      cutlass::arch::ldsm<LoadLayout, 4>(
          access_ptr[access_m_idx], ref_.data() + ref_.offset(offset));
    }
  }
};

////////////////////////////////////////////////////////////////////////////////

} // namespace warp
} // namespace gemm
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////