/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/

#pragma once

#include <cutlass/aligned_buffer.h>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>

#include <cutlass/matrix_shape.h>
#include <cutlass/numeric_types.h>

#include <cutlass/gemm/gemm.h>

#include <ATen/native/cuda/cutlass_extensions/gemm/threadblock/dq_mma_base.h>
#include <ATen/native/cuda/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h>
#include <ATen/native/cuda/cutlass_extensions/interleaved_numeric_conversion.h>

#include <ATen/native/cuda/cutlass_extensions/ft_gemm_configs.h>
#include <ATen/native/cuda/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h>

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Structure to compute the matrix product for a threadblock, dequantizing the B operand after it is loaded from shared memory.
template<
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    /// Iterates over tiles of A operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
    typename IteratorA_,
    /// Iterates over tiles of A operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorA_,
    /// Iterates over tiles of B operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
    typename IteratorB_,
    /// Iterates over tiles of B operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorB_,
    /// Iterates over tiles of the scale operand in global memory
    typename IteratorScale_,
    /// Iterates over tiles of the scale operand in shared memory
    typename SmemIteratorScale_,
    /// Data type of accumulator matrix
    typename ElementC_,
    /// Layout of accumulator matrix
    typename LayoutC_,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy_,
    /// Converter for B matrix applied immediately after the LDG (before STS)
    typename TransformBAfterLDG_,
    /// Converter for B matrix applied immediately after the LDS
    typename TransformBAfterLDS_,
    /// Used for partial specialization
    typename Enable = bool>
class DqMmaPipelined: public DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2> {
public:
    ///< Base class
    using Base = DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2>;

    using Shape     = Shape_;      ///< Size of the Gemm problem - concept: gemm::GemmShape<>
    using IteratorA = IteratorA_;  ///< Iterates over tiles of A operand in global memory
    using IteratorB = IteratorB_;  ///< Iterates over tiles of B operand in global memory
    using ElementC  = ElementC_;   ///< Data type of accumulator matrix
    using LayoutC   = LayoutC_;    ///< Layout of accumulator matrix
    using Policy    = Policy_;     ///< Policy describing tuning details

    using IteratorScale = IteratorScale_;
    using ElementScale  = typename IteratorScale::Element;
    using LayoutScale   = typename IteratorScale::Layout;

    using SmemIteratorA     = SmemIteratorA_;
    using SmemIteratorB     = SmemIteratorB_;
    using SmemIteratorScale = SmemIteratorScale_;

    using TransformBAfterLDG = TransformBAfterLDG_;
    using TransformBAfterLDS = TransformBAfterLDS_;

    //
    // Dependent types
    //

    /// Fragment of operand A loaded from global memory
    using FragmentA = typename IteratorA::Fragment;

    /// Fragment of operand B loaded from global memory
    using FragmentB = typename IteratorB::Fragment;

    /// Fragment of operand Scale loaded from global memory
    using FragmentScale = typename IteratorScale::Fragment;

    /// Fragment of accumulator tile
    using FragmentC = typename Policy::Operator::FragmentC;

    /// Warp-level Mma
    using Operator = typename Policy::Operator;

    /// Obtain the arch tag from the warp-level operator
    using ArchTag = typename Policy::Operator::ArchTag;

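    /// Warp-level dequantizer that applies the per-column scales to B fragments after they are loaded
    /// from shared memory and converted by TransformBAfterLDS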
    using Dequantizer = warp::MmaTensorOpDequantizer<Operator,
                                                     typename Base::WarpGemm,
                                                     Operand::kB,
                                                     typename SmemIteratorScale::Fragment::Element,
                                                     LayoutScale,
                                                     32>;

    /// Complex transform on A operand
    static ComplexTransform const kTransformA = Operator::kTransformA;

    /// Complex transform on B operand
    static ComplexTransform const kTransformB = Operator::kTransformB;

    // Statically assert that kStages for DqMmaPipelined is two (double-buffered pipeline)
    static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2");

private:
    using WarpFragmentA = typename Operator::FragmentA;
    using WarpFragmentB = typename Operator::FragmentB;
    Dequantizer warp_dequantizer_;

    using ElementB          = typename IteratorB::Element;
    using LayoutDetailsForB = kernel::LayoutDetailsB<ElementB, ArchTag>;

    static constexpr bool RequiresTileInterleave =
        layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
    static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
                  "Layout K must match threadblockK");

protected:
    /// Iterator to write threadblock-scoped tile of A operand to shared memory
    SmemIteratorA smem_iterator_A_;

    /// Iterator to write threadblock-scoped tile of B operand to shared memory
    SmemIteratorB smem_iterator_B_;

    /// Iterator to write threadblock-scoped tile of scale operand to shared memory
    SmemIteratorScale smem_iterator_scale_;

public:
    /// Construct from tensor references
    CUTLASS_DEVICE
    DqMmaPipelined(typename Base::SharedStorage&
                       shared_storage,  ///< Shared storage needed for internal use by threadblock-scoped GEMM
                   int thread_idx,      ///< ID within the threadblock
                   int warp_idx,        ///< ID of warp
                   int lane_idx         ///< ID of each thread within a warp
                   ):
        Base(shared_storage, thread_idx, warp_idx, lane_idx),
        warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
                          (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM,
                          lane_idx),
        smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
        smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx),
        smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx)
    {

        // Compute warp location within threadblock tile by mapping the warp_id to
        // three coordinates:
        //   _m: the warp's position within the threadblock along the M dimension
        //   _n: the warp's position within the threadblock along the N dimension
        //   _k: the warp's position within the threadblock along the K dimension
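        //
        // For example, with Base::WarpCount = {kM = 2, kN = 2, kK = 1}, warp_idx 0..3 maps to
        // (_m, _n, _k) = (0,0,0), (1,0,0), (0,1,0), (1,1,0).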

        int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
        int warp_idx_k  = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);

        int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
        int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;

        // Add per-warp offsets in units of warp-level tiles
        this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
        this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n});
    }

    /// Perform a threadblock-scoped matrix multiply-accumulate
    CUTLASS_DEVICE
    void operator()(int              gemm_k_iterations,  ///< number of iterations of the mainloop
                    FragmentC&       accum,              ///< destination accumulator tile
                    IteratorA        iterator_A,         ///< iterator over A operand in global memory
                    IteratorB        iterator_B,         ///< iterator over B operand in global memory
                    IteratorScale    iterator_scale,     ///< iterator over scale operand in global memory
                    FragmentC const& src_accum)          ///< source accumulator tile
    {

        //
        // Prologue
        //
        TransformBAfterLDG ldg_converter;
        TransformBAfterLDS lds_converter;

        using TransformA =
            NumericArrayConverter<typename WarpFragmentA::Element, typename FragmentA::Element, FragmentA::kElements>;

        using TransformScale = NumericArrayConverter<typename SmemIteratorScale::Fragment::Element,
                                                     typename FragmentScale::Element,
                                                     FragmentScale::kElements>;

        // These transforms mainly handle the case where we have bfloat16 activations and weights in GMEM
        // and want to issue HMMA on architectures older than Ampere. We convert to FP16 before the STS.
        TransformA     transformA;
        TransformScale transformScale;

        // Perform accumulation in the 'd' output operand
        accum = src_accum;

        FragmentA     tb_frag_A;
        FragmentB     tb_frag_B;
        FragmentScale tb_frag_scales;

        using WarpFragmentScale = typename Dequantizer::FragmentScale;
        WarpFragmentScale warp_frag_scales;

        tb_frag_A.clear();
        tb_frag_B.clear();
        tb_frag_scales.clear();

        // The last kblock is loaded in the prologue
        iterator_A.load(tb_frag_A);
        iterator_B.load(tb_frag_B);
        iterator_scale.load(tb_frag_scales);

        ++iterator_A;
        ++iterator_B;
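        // iterator_scale is not advanced: the per-column scales are loaded once here and reused for
        // every k iteration of the mainloop.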

        this->smem_iterator_A_.store(transformA(tb_frag_A));
        this->smem_iterator_B_.store(ldg_converter(tb_frag_B));
        this->smem_iterator_scale_.store(transformScale(tb_frag_scales));

        ++this->smem_iterator_A_;
        ++this->smem_iterator_B_;

        __syncthreads();

        warp_dequantizer_.load(warp_frag_scales);

        // Pair of fragments used to overlap shared memory loads and math instructions
        WarpFragmentA warp_frag_A[2];
        WarpFragmentB warp_frag_B[2];

        this->warp_tile_iterator_A_.set_kgroup_index(0);
        this->warp_tile_iterator_B_.set_kgroup_index(0);

        this->warp_tile_iterator_A_.load(warp_frag_A[0]);
        this->warp_tile_iterator_B_.load(warp_frag_B[0]);

        ++this->warp_tile_iterator_A_;
        ++this->warp_tile_iterator_B_;

        Operator warp_mma;

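        // Shared memory is double buffered: the prologue above filled stage 0, so the next global tile
        // is written to stage 1. The index toggles each time a new tile is staged.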
        int smem_write_stage_idx = 1;

        // Avoid reading out of bounds
        iterator_A.clear_mask(gemm_k_iterations <= 1);
        iterator_B.clear_mask(gemm_k_iterations <= 1);

        // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
        // shared memory loads (which have the tightest latency requirement).

        //
        // Mainloop
        //

        // Note: The main loop does not support Base::kWarpGemmIterations == 2.
        CUTLASS_GEMM_LOOP
        for (; gemm_k_iterations > 0; --gemm_k_iterations) {
            //
            // Loop over GEMM K dimension
            //

            CUTLASS_PRAGMA_UNROLL
            for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {

                // Load warp-level tiles from shared memory, wrapping back to the first k group when this
                // is the last group in the tile.

                if (warp_mma_k == Base::kWarpGemmIterations - 1) {

                    // Write fragments to shared memory
                    this->smem_iterator_A_.store(transformA(tb_frag_A));

                    this->smem_iterator_B_.store(ldg_converter(tb_frag_B));

                    __syncthreads();

                    ++this->smem_iterator_A_;
                    ++this->smem_iterator_B_;

                    // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
                    if (smem_write_stage_idx == 1) {
                        this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
                        this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
                    }
                    else {
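                        // Rewind the warp-level read iterators to the start of the circular buffer.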
                        this->warp_tile_iterator_A_.add_tile_offset(
                            {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
                        this->warp_tile_iterator_B_.add_tile_offset(
                            {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0});
                    }

                    smem_write_stage_idx ^= 1;
                }

                this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
                this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
                ++this->warp_tile_iterator_A_;

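                // B may be loaded from shared memory at a coarser k granularity than it is consumed by the
                // warp-level MMA, so warp_mma_k is split into a compute offset within the current B fragment
                // and a load offset that selects which B fragment is in flight.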
                const int warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
                const int warp_tileB_k_load_offset    = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
                // We are just about to finish computing on a fragment of B, so initiate the load for the next fragment.
                if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) {
                    this->warp_tile_iterator_B_.set_kgroup_index((warp_tileB_k_load_offset + 1)
                                                                 % Base::kWarpGemmIterationsForB);
                    this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
                    ++this->warp_tile_iterator_B_;
                }

                if (warp_mma_k == 0) {

                    iterator_A.load(tb_frag_A);
                    iterator_B.load(tb_frag_B);

                    ++iterator_A;
                    ++iterator_B;

                    // Avoid reading out of bounds if this was the last loop iteration
                    iterator_A.clear_mask(gemm_k_iterations <= 2);
                    iterator_B.clear_mask(gemm_k_iterations <= 2);
                }

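                // Convert the packed B fragment to the compute type, apply the per-column scales, and issue
                // the warp-level MMA on the current A/B fragments.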
                typename TransformBAfterLDS::result_type converted_frag_B =
                    lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
                warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales);
                run_warp_mma(
                    warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset);
            }
        }
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace threadblock
}  // namespace gemm
}  // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////