/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/

#pragma once

#include <cutlass/aligned_buffer.h>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>

#include <cutlass/matrix_shape.h>
#include <cutlass/numeric_types.h>

#include <cutlass/gemm/gemm.h>

#include <ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma_base.h>

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    /// Iterates over tiles of A operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator |
    //  MaskedTileIterator)
    typename IteratorA_,
    /// Iterates over tiles of A operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorA_,
    /// Iterates over tiles of B operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator |
    //  MaskedTileIterator)
    typename IteratorB_,
    /// Iterates over tiles of B operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorB_,
    /// Data type of accumulator matrix
    typename ElementC_,
    /// Layout of accumulator matrix
    typename LayoutC_,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy_,
    /// Transformation applied to A operand
    typename TransformA_ = NumericArrayConverter<
        typename SmemIteratorA_::Element,
        typename IteratorA_::Element,
        IteratorA_::Fragment::kElements>,
    ///
    /// Transformation applied to B operand
    typename TransformB_ = NumericArrayConverter<
        typename SmemIteratorB_::Element,
        typename IteratorB_::Element,
        IteratorB_::Fragment::kElements>,
    /// Used for partial specialization
    typename Enable = bool>
class CustomMmaPipelined : public CustomMmaBase<Shape_, Policy_, 2> {
 public:
  ///< Base class
  using Base = CustomMmaBase<Shape_, Policy_, 2>;

  using Shape =
      Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
  using IteratorA =
      IteratorA_; ///< Iterates over tiles of A operand in global memory
  using IteratorB =
      IteratorB_; ///< Iterates over tiles of B operand in global memory
  using ElementC = ElementC_; ///< Data type of accumulator matrix
  using LayoutC = LayoutC_; ///< Layout of accumulator matrix
  using Policy = Policy_; ///< Policy describing tuning details

  using SmemIteratorA = SmemIteratorA_;
  using SmemIteratorB = SmemIteratorB_;

  using TransformA = TransformA_;
  using TransformB = TransformB_;

  //
  // Dependent types
  //

  /// Fragment of operand A loaded from global memory
  using FragmentA = typename IteratorA::Fragment;

  /// Fragment of operand B loaded from global memory
  using FragmentB = typename IteratorB::Fragment;

  /// Fragment of accumulator tile
  using FragmentC = typename Policy::Operator::FragmentC;

  /// Warp-level Mma
  using Operator = typename Policy::Operator;

  /// Obtain the arch tag from the warp-level operator
  using ArchTag = typename Policy::Operator::ArchTag;

  /// Complex transform on A operand
  static ComplexTransform const kTransformA = Operator::kTransformA;

  /// Complex transform on B operand
  static ComplexTransform const kTransformB = Operator::kTransformB;

  // Statically assert that kStages for MmaPipelined is two (double-buffered pipeline)
  static_assert(
      (Base::kStages == 2),
      "MmaPipelined requires kStages set to value 2");

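  // This pipelined mainloop streams k-tiles through the two shared-memory
  // stages, so shared memory never holds an entire operand matrix at once.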
  static bool const kSmemContainsEntireMat = false;

 private:
  using WarpFragmentA = typename Operator::FragmentA;
  using WarpFragmentB = typename Operator::FragmentB;

 protected:
  /// Iterator to write threadblock-scoped tile of A operand to shared memory
  SmemIteratorA smem_iterator_A_;

  /// Iterator to write threadblock-scoped tile of B operand to shared memory
  SmemIteratorB smem_iterator_B_;

 public:
  /// Construct from tensor references
  CUTLASS_DEVICE
  CustomMmaPipelined(
      typename Base::SharedStorageA& shared_storageA,
      typename Base::SharedStorageB& shared_storageB,
      int thread_idx, ///< ID within the threadblock
      int warp_idx, ///< ID of warp
      int lane_idx ///< ID of each thread within a warp
      )
      : Base(shared_storageA, shared_storageB, thread_idx, warp_idx, lane_idx),
        smem_iterator_A_(shared_storageA.ref(), thread_idx),
        smem_iterator_B_(shared_storageB.ref(), thread_idx) {
    // Compute warp location within threadblock tile by mapping the warp_id to
    // three coordinates:
    //   _m: the warp's position within the threadblock along the M dimension
    //   _n: the warp's position within the threadblock along the N dimension
    //   _k: the warp's position within the threadblock along the K dimension

    int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
    int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);

    int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
    int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;

    // Add per-warp offsets in units of warp-level tiles
    this->warp_tile_iterator_A_.add_tile_offset(
        {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
    this->warp_tile_iterator_B_.add_tile_offset(
        {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
  }
  CUTLASS_DEVICE
  CustomMmaPipelined(
      ///< Shared storage needed for internal use by threadblock-scoped GEMM
      typename Base::SharedStorage& st,
      ///< ID within the threadblock
      int thread_idx,
      ///< ID of warp
      int warp_idx,
      ///< ID of each thread within a warp
      int lane_idx)
      : CustomMmaPipelined(
            st.operand_A,
            st.operand_B,
            thread_idx,
            warp_idx,
            lane_idx) {}

  CUTLASS_DEVICE
  void set_prologue_done(bool value) {
    // NOT IMPLEMENTED FOR PIPELINED
  }

  CUTLASS_DEVICE
  void set_zero_outside_bounds(bool value) {
    // NOT NEEDED FOR PIPELINED
    // shared memory will always be zero-filled
  }

  template <bool kLoadA = true, bool kLoadB = true>
  CUTLASS_DEVICE static void prologue(
      typename Base::SharedStorage& shared_storage,
      ///< iterator over A operand in global memory
      IteratorA iterator_A,
      ///< iterator over B operand in global memory
      IteratorB iterator_B,
      int thread_idx,
      int problem_size_k) {
    prologue<kLoadA, kLoadB>(
        shared_storage.operand_A,
        shared_storage.operand_B,
        iterator_A,
        iterator_B,
        thread_idx,
        problem_size_k);
  }

  template <bool kLoadA = true, bool kLoadB = true>
  CUTLASS_DEVICE static void prologue(
      typename Base::SharedStorageA& shared_storageA,
      typename Base::SharedStorageB& shared_storageB,
      ///< iterator over A operand in global memory
      IteratorA iterator_A,
      ///< iterator over B operand in global memory
      IteratorB iterator_B,
      int thread_idx,
      int problem_size_k) {
    // NOT IMPLEMENTED FOR PIPELINED
  }

  /// Perform a threadblock-scoped matrix multiply-accumulate
  CUTLASS_DEVICE
  void operator()(
      int gemm_k_iterations, ///< number of iterations of the mainloop
      FragmentC& accum, ///< destination accumulator tile
      IteratorA iterator_A, ///< iterator over A operand in global memory
      IteratorB iterator_B, ///< iterator over B operand in global memory
      FragmentC const& src_accum, ///< source accumulator tile
      TransformA transform_A =
          TransformA(), ///< transformation applied to A fragment
      TransformB transform_B =
          TransformB()) { ///< transformation applied to B fragment

    //
    // Prologue
    //

    // Perform accumulation in the 'd' output operand
    accum = src_accum;

    FragmentA tb_frag_A;
    FragmentB tb_frag_B;

    tb_frag_A.clear();
    tb_frag_B.clear();

    // The last k-block is loaded in the prologue
    iterator_A.load(tb_frag_A);
    iterator_B.load(tb_frag_B);

    ++iterator_A;
    ++iterator_B;

    this->smem_iterator_A_.store(transform_A(tb_frag_A));
    this->smem_iterator_B_.store(transform_B(tb_frag_B));

    ++this->smem_iterator_A_;
    ++this->smem_iterator_B_;

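    // Barrier: make the prologue's shared-memory stores visible to all warps
    // before the first warp-level fragments are loaded.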
    __syncthreads();

    // Pair of fragments used to overlap shared memory loads and math
    // instructions
    WarpFragmentA warp_frag_A[2];
    WarpFragmentB warp_frag_B[2];

    this->warp_tile_iterator_A_.set_kgroup_index(0);
    this->warp_tile_iterator_B_.set_kgroup_index(0);

    this->warp_tile_iterator_A_.load(warp_frag_A[0]);
    this->warp_tile_iterator_B_.load(warp_frag_B[0]);

    ++this->warp_tile_iterator_A_;
    ++this->warp_tile_iterator_B_;

    Operator warp_mma;

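    // Index of the shared-memory stage that receives the next global-memory
    // fragments; the prologue filled stage 0, so writing continues at stage 1.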
    int smem_write_stage_idx = 1;

    // Avoid reading out of bounds
    iterator_A.clear_mask(gemm_k_iterations <= 1);
    iterator_B.clear_mask(gemm_k_iterations <= 1);

    // Issue loads during the first warp-level matrix multiply-add *AFTER*
    // issuing shared memory loads (which have the tightest latency requirement).

    //
    // Mainloop
    //

    // Note: The main loop does not support Base::kWarpGemmIterations == 2.
    CUTLASS_GEMM_LOOP
    for (; gemm_k_iterations > 0; --gemm_k_iterations) {
      //
      // Loop over GEMM K dimension
      //

      CUTLASS_PRAGMA_UNROLL
      for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations;
           ++warp_mma_k) {
        // Load warp-level tiles from shared memory, wrapping to k offset if
        // this is the last group as the case may be.

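        // On the last warp-level k-group of this tile, commit the prefetched
        // global-memory fragments to the other shared-memory stage and wrap
        // the write/read iterators around the two-stage circular buffer.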
        if (warp_mma_k == Base::kWarpGemmIterations - 1) {
          // Write fragments to shared memory
          this->smem_iterator_A_.store(transform_A(tb_frag_A));

          this->smem_iterator_B_.store(transform_B(tb_frag_B));

          __syncthreads();

          ++this->smem_iterator_A_;
          ++this->smem_iterator_B_;

          // Add negative offsets to return iterators to the 'start' of the
          // circular buffer in shared memory
          if (smem_write_stage_idx == 1) {
            this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
            this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
          } else {
            this->warp_tile_iterator_A_.add_tile_offset(
                {0,
                 -Base::kStages * Policy::kPartitionsK *
                     Base::kWarpGemmIterations});
            this->warp_tile_iterator_B_.add_tile_offset(
                {-Base::kStages * Policy::kPartitionsK *
                     Base::kWarpGemmIterations,
                 0});
          }

          smem_write_stage_idx ^= 1;
        }

        this->warp_tile_iterator_A_.set_kgroup_index(
            (warp_mma_k + 1) % Base::kWarpGemmIterations);
        this->warp_tile_iterator_B_.set_kgroup_index(
            (warp_mma_k + 1) % Base::kWarpGemmIterations);

        this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
        this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]);

        ++this->warp_tile_iterator_A_;
        ++this->warp_tile_iterator_B_;

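        // At the start of each tile, kick off the global-memory loads for the
        // next threadblock tile so they overlap with this tile's warp-level MMAs.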
        if (warp_mma_k == 0) {
          iterator_A.load(tb_frag_A);
          iterator_B.load(tb_frag_B);

          ++iterator_A;
          ++iterator_B;

          // Avoid reading out of bounds if this was the last loop iteration
          iterator_A.clear_mask(gemm_k_iterations <= 2);
          iterator_B.clear_mask(gemm_k_iterations <= 2);
        }

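        // Issue the warp-level MMA on the fragments loaded in the previous
        // inner iteration (or in the prologue for the very first one).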
        warp_mma(
            accum,
            warp_frag_A[warp_mma_k % 2],
            warp_frag_B[warp_mma_k % 2],
            accum);
      }
    }
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////