xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/transformers/cuda/mem_eff_attention/gemm/mma_from_smem.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 /***************************************************************************************************
2  * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
3  *reserved. SPDX-License-Identifier: BSD-3-Clause
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions are met:
7  *
8  * 1. Redistributions of source code must retain the above copyright notice,
9  *this list of conditions and the following disclaimer.
10  *
11  * 2. Redistributions in binary form must reproduce the above copyright notice,
12  * this list of conditions and the following disclaimer in the documentation
13  * and/or other materials provided with the distribution.
14  *
15  * 3. Neither the name of the copyright holder nor the names of its
16  * contributors may be used to endorse or promote products derived from
17  * this software without specific prior written permission.
18  *
19  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
22  *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
23  *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
24  *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
25  *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
26  *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
27  *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
28  *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
29  *POSSIBILITY OF SUCH DAMAGE.
30  *
31  **************************************************************************************************/
32 /*! \file
33     \brief Template for a double-buffered threadblock-scoped GEMM kernel.
34 */
35 
36 #pragma once
37 
38 #include <cutlass/aligned_buffer.h>
39 #include <cutlass/arch/memory.h>
40 #include <cutlass/array.h>
41 #include <cutlass/cutlass.h>
42 #include <cutlass/epilogue/thread/linear_combination.h>
43 #include <cutlass/epilogue/threadblock/default_epilogue_simt.h>
44 #include <cutlass/epilogue/threadblock/default_epilogue_tensor_op.h>
45 #include <cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h>
46 #include <cutlass/functional.h>
47 #include <cutlass/gemm/gemm.h>
48 #include <cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h>
49 #include <cutlass/matrix_shape.h>
50 #include <cutlass/numeric_conversion.h>
51 #include <cutlass/numeric_types.h>
52 #include <cutlass/platform/platform.h>
53 #include <cutlass/transform/threadblock/vector_iterator.h>
54 
55 #include <cutlass/epilogue/threadblock/epilogue_smem_accumulator.h>
56 #include <cutlass/gemm/threadblock/mma_base.h>
57 #include <cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h>
58 #include <cutlass/gemm/threadblock/mma_pipelined.h>
59 #include <cutlass/gemm/threadblock/mma_multistage.h>
60 
61 #include <ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_thread_apply_logsumexp.h>
62 #include <ATen/native/transformers/cuda/mem_eff_attention/gemm_kernel_utils.h>
63 #include <ATen/native/transformers/cuda/mem_eff_attention/iterators/make_residual_last.h>
64 #include <ATen/native/transformers/cuda/mem_eff_attention/gemm/mma_accum_lambda_iterator.h>
65 
66 #include <ATen/native/transformers/cuda/mem_eff_attention/iterators/default_warp_iterator_from_smem.h>
67 #include <ATen/native/transformers/cuda/mem_eff_attention/iterators/make_residual_last.h>
68 #include <ATen/native/transformers/cuda/mem_eff_attention/iterators/transpose_warp_iterator.h>
69 #include <ATen/native/transformers/cuda/mem_eff_attention/iterators/warp_iterator_from_smem.h>
70 
71 /////////////////////////////////////////////////////////////////////////////////////////////////
72 
73 namespace cutlass {
74 namespace gemm {
75 namespace threadblock {
76 
77 /// Shared storage object needed by accumulator
78 /// From 13_two_tensor_op_fusion/threadblock/b2b_mma_base_smem_accumulator.h
79 template <
80     typename Shape_,
81     typename Element_,
82     typename Layout_,
83     typename Padding_>
84 class AccumulatorSharedStorage {
85  public:
86   //
87   // Type definitions
88   //
89   using Shape = Shape_;
90   using Element = Element_;
91   using Layout = Layout_;
92   using Padding = Padding_;
93 
94   /// Tensor reference to the accumulator
95   using TensorRefAccum = cutlass::TensorRef<Element, Layout>;
96 
97   /// Shape of the accumulator matrix in shared memory
98   using ShapeAccum = cutlass::
99       MatrixShape<Shape::kM + Padding::kRow, Shape::kN + Padding::kColumn>;
100 
101  public:
102   //
103   // Data members
104   //
105 
106   /// Buffer for accumulator
107   cutlass::AlignedBuffer<Element, ShapeAccum::kCount> accum;
108 
109  public:
110   //
111   // Methods
112   //
113 
114   /// Returns a layout object for the Accum matrix
115   CUTLASS_DEVICE
116   static Layout LayoutAccum() {
117     return Layout::packed({ShapeAccum::kRow, ShapeAccum::kColumn});
118   }
119 
120   /// Returns a TensorRef to the Accumulator
121   CUTLASS_HOST_DEVICE
122   TensorRefAccum accum_ref() {
123     return TensorRefAccum{accum.data(), LayoutAccum()};
124   }
125 };
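
// Illustrative sizing sketch (not part of the original file; the numbers are
// assumptions): with, say, Shape = cutlass::MatrixShape<64, 64>,
// Element = cutlass::half_t and Padding = cutlass::MatrixShape<0, 8>,
// ShapeAccum is 64 x 72, so `accum` holds 64 * 72 = 4608 half_t elements
// (9 KiB of shared memory). The padding columns are typically there to reduce
// shared-memory bank conflicts.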
126 
127 ////////////////////////////////////////////////////////////////////////////////
128 // Taken from
129 // https://github.com/NVIDIA/cutlass/blob/master/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_base_smem_accumulator.h
130 ////////////////////////////////////////////////////////////////////////////////
131 
132 /// Structure to compute the matrix product targeting CUDA cores and SIMT math
133 /// instructions.
134 template <
135     /// Size of the Gemm problem - concept: gemm::GemmShape<>
136     typename Shape_,
137     // Maximum K dimension - also the dimension of the shared-memory
138     // holding `OperandA`
139     int kMaxK_,
140     /// Policy describing tuning details (concept: MmaPolicy)
141     typename Policy_,
142     /// Number of stages,
143     int Stages,
144     /// Layout in shared-memory of operand A
145     typename SmemLayoutA,
146     /// Used for partial specialization
147     typename Enable = bool>
148 class MmaBaseFromSharedMemory {
149  public:
150   ///< Size of the Gemm problem - concept: gemm::GemmShape<>
151   using Shape = Shape_;
152   static constexpr int kMaxK = kMaxK_;
153 
154   ///< Policy describing tuning details
155   using Policy = Policy_;
156 
157   //
158   // Dependent types
159   //
160 
161   /// Warp-level Mma
162   using Operator = typename Policy::Operator;
163 
164   /// Shape describing the overall GEMM computed from shared memory
165   /// by each warp.
166   using WarpGemm = typename Policy::Operator::Shape;
167 
168   /// Shape describing the number of warps filling the CTA
169   using WarpCount = GemmShape<
170       Shape::kM / WarpGemm::kM,
171       Shape::kN / WarpGemm::kN,
172       Shape::kK / WarpGemm::kK>;
173   using WarpCount1 = WarpCount;
174 
175   /// Number of warp-level GEMM operations
176   static int const kWarpGemmIterations =
177       (WarpGemm::kK / Operator::Policy::MmaShape::kK);
178   static int const kWarpGemmIterations1 = kWarpGemmIterations;
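  // Worked example (illustrative numbers, not taken from this file): with
  // Shape = GemmShape<64, 128, 32> and WarpGemm = GemmShape<32, 64, 32>,
  // WarpCount is <2, 2, 1> (4 warps per CTA); if the instruction-level
  // MmaShape::kK is 16, then kWarpGemmIterations = 32 / 16 = 2.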
179 
180   /// Number of stages
181   static int const kStages = Stages;
182 
183   /// If this is true, we fill the entire shmem buffer at start
184   /// and don't need to iterate through it in a circular fashion
185   static bool const kSmemContainsEntireB = kMaxK <= Shape::kK * kStages;
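  // For example (illustrative): with Shape::kK = 32 and kStages = 3, any
  // problem with kMaxK <= 96 keeps all of operand B resident in shared memory,
  // so the mainloop never has to refill the circular buffer.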
186 
187   /// Tensor reference to the A operand
188   using TensorRefA = TensorRef<typename Operator::ElementA, SmemLayoutA>;
189 
190   /// Tensor reference to the B operand
191   using TensorRefB =
192       TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;
193 
194   //
195   // Nested structs
196   //
197 
198   /// Shared storage object needed by threadblock-scoped GEMM
199   class SharedStorage {
200    public:
201     //
202     // Type definitions
203     //
204 
205     /// Shape of the B matrix operand in shared memory
206     using ShapeB = MatrixShape<
207         Shape::kK * kStages + Policy::SmemPaddingB::kRow,
208         Shape::kN + Policy::SmemPaddingB::kColumn>;
209 
210    public:
211     //
212     // Data members
213     //
214 
215     /// Buffer for B operand
216     AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;
217 
218    public:
219     //
220     // Methods
221     //
222 
223     /// Returns a layout object for the B matrix
224     CUTLASS_HOST_DEVICE
225     static typename Operator::LayoutB LayoutB() {
226       return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
227     }
228 
229     /// Returns a TensorRef to the B operand
230     CUTLASS_HOST_DEVICE
231     TensorRefB operand_B_ref() {
232       return TensorRefB{operand_B.data(), LayoutB()};
233     }
234   };
235 
236  protected:
237   //
238   // Data members
239   //
240 
241   // /// Iterator to load a warp-scoped tile of A operand from shared memory
242   // typename Operator::IteratorA warp_tile_iterator_A_;
243 
244   /// Iterator to load a warp-scoped tile of B operand from shared memory
245   typename Operator::IteratorB warp_tile_iterator_B_;
246 
247  public:
248   /// Construct from tensor references
249   CUTLASS_DEVICE
250   MmaBaseFromSharedMemory(
251       ///< Shared storage needed for internal use by threadblock-scoped GEMM
252       TensorRefB& b_tile,
253       ///< ID within the threadblock
254       int thread_idx,
255       ///< ID of warp
256       int warp_idx,
257       ///< ID of each thread within a warp
258       int lane_idx)
259       : warp_tile_iterator_B_(b_tile, lane_idx) {}
260 };
261 
262 namespace {
263 
264 // has necessary trait compliance with WarpIteratorFromSmem but doesn't do
265 // anything, can be default initialized, and uses a fragment that takes up
266 // (almost) no space. this warp iterator is selected at compile time when
267 // elementwise on-the-fly scaling for operand A is disabled, in which case
268 // operations related to loading scale factors for operand A get wiped out by
269 // the compiler.
270 template <typename TensorRef>
271 class NoOpWarpIteratorScale {
272  public:
273   // in pipelined+multistage MMA implementations we keep an array of fragments.
274   // if we aren't using scaling we don't want to waste registers on fragments
275   // of scale elements, so ideally this would be sized 0.
276   // Since arrays of zero-sized objects are not allowed, we use a size of 1.
277   // The compiler will most likely wipe it out anyways.
278   using Fragment = cutlass::Array<char, 1>;
279 
280   CUTLASS_HOST_DEVICE
281   NoOpWarpIteratorScale() {}
282 
283   CUTLASS_HOST_DEVICE
284   NoOpWarpIteratorScale(TensorRef const&, int) {}
285 
286   CUTLASS_HOST_DEVICE
287   NoOpWarpIteratorScale& add_tile_offset(
288       typename TensorRef::TensorCoord const&) {
289     return *this;
290   }
291 
292   CUTLASS_HOST_DEVICE
293   NoOpWarpIteratorScale& operator++() {
294     return *this;
295   }
296 
297   CUTLASS_DEVICE
298   void load(Fragment&) const {}
299 };
300 
301 // if scaling is enabled, performs elementwise multiplication between a
302 // fragment and its scaling factor.
303 template <typename Fragment, typename FragmentScale, bool ScalingEnabled>
304 class FragmentElementwiseScaler;
305 
306 // specialization for scaling being enabled.
307 template <typename Fragment, typename FragmentScale>
308 class FragmentElementwiseScaler<Fragment, FragmentScale, true> {
309  public:
310   // cast scale_frag to correct type then apply elementwise to fragment
311   CUTLASS_DEVICE
312   static Fragment apply(Fragment frag, FragmentScale const& scale_frag) {
313     Fragment converted_scale_frag = cutlass::NumericArrayConverter<
314         typename Fragment::Element,
315         typename FragmentScale::Element,
316         FragmentScale::kElements>()(scale_frag);
317     return cutlass::multiplies<Fragment>()(frag, converted_scale_frag);
318   }
319 };
320 
321 // specialization for scaling being disabled. doesn't do anything and should
322 // just get wiped out by the compiler.
323 template <typename Fragment, typename FragmentScale>
324 class FragmentElementwiseScaler<Fragment, FragmentScale, false> {
325  public:
326   CUTLASS_DEVICE
327   static Fragment apply(Fragment frag, FragmentScale const&) {
328     return frag;
329   }
330 };
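
// How these helpers are meant to compose (sketch only; `ScaleOperandA` and the
// iterator/fragment names below mirror the classes further down in this file):
//
//   using WarpIteratorAScale = typename cutlass::platform::conditional<
//       ScaleOperandA,
//       WarpIteratorA,                                       // real loads
//       NoOpWarpIteratorScale<typename WarpIteratorA::TensorRef>>::type;
//
//   // In the mainloop:
//   warp_mma(accum,
//            FragmentAScaler::apply(frag_A, frag_A_scale),   // scale or pass-through
//            frag_B,
//            accum);
//
// When ScaleOperandA is false, both the iterator and the apply() call collapse
// to no-ops that the compiler can remove entirely.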
331 } // namespace
332 
333 ////////////////////////////////////////////////////////////////////////////////
334 // Taken from
335 // https://github.com/NVIDIA/cutlass/blob/master/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h
336 ////////////////////////////////////////////////////////////////////////////////
337 
338 /// Structure to compute the matrix product targeting CUDA cores and SIMT math
339 /// instructions.
340 template <
341     /// Size of the Gemm problem - concept: gemm::GemmShape<>
342     typename Shape_,
343     // BEGIN smem
344     /// Iterates over the intermediate accumulator tile in shared memory
345     typename WarpIteratorA_,
346     /// whether or not to perform elementwise multiplication of A
347     //  by another matrix (A_scale) that is also kept in shared memory prior
348     //  to matmul A @ B
349     bool ScaleOperandA_,
350     /// Max GEMM problem size in K dimension
351     int MaxK,
352     /// Iterates over tiles of B operand in global memory
353     //  (concept: ReadableTileIterator | ForwardTileIterator |
354     //  MaskedTileIterator)
355     typename IteratorB_,
356     /// Iterates over tiles of B operand in shared memory
357     /// (concept: WriteableTileIterator | RandomAccessTileIterator)
358     typename SmemIteratorB_,
359     /// Data type of accumulator matrix
360     typename ElementC_,
361     /// Data type of accumulator matrix
362     typename LayoutC_,
363     /// Policy describing tuning details (concept: MmaPolicy)
364     typename Policy_,
365     /// Transformation applied to B operand
366     typename TransformB_ = NumericArrayConverter<
367         typename SmemIteratorB_::Element,
368         typename IteratorB_::Element,
369         IteratorB_::Fragment::kElements>,
370     /// Used for partial specialization
371     typename Enable = bool>
372 class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
373                                          Shape_,
374                                          MaxK,
375                                          Policy_,
376                                          2,
377                                          typename WarpIteratorA_::Layout> {
378  public:
379   ///< Base class
380   using Base = MmaBaseFromSharedMemory<
381       Shape_,
382       MaxK,
383       Policy_,
384       2,
385       typename WarpIteratorA_::Layout>;
386 
387   using Shape =
388       Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
389   static constexpr bool ScaleOperandA = ScaleOperandA_;
390 
391   using WarpIteratorA = WarpIteratorA_;
392   ///< loads fragments of A_scale from shared memory if operand A scaling is
393   ///< enabled. otherwise no-op.
394   using WarpIteratorAScale = typename cutlass::platform::conditional<
395       ScaleOperandA,
396       WarpIteratorA,
397       NoOpWarpIteratorScale<typename WarpIteratorA::TensorRef>>::type;
398 
399   using IteratorB =
400       IteratorB_; ///< Iterates over tiles of B operand in global memory
401   using ElementC = ElementC_; ///< Data type of accumulator matrix
402   using LayoutC = LayoutC_; ///< Layout of accumulator matrix
403   using Policy = Policy_; ///< Policy describing tuning details
404 
405   using SmemIteratorB = SmemIteratorB_;
406 
407   using TransformB = TransformB_;
408 
409   //
410   // Dependent types
411   //
412 
413   /// Fragment of operand B loaded from global memory
414   using FragmentB = typename IteratorB::Fragment;
415 
416   /// Fragment of accumulator tile
417   using FragmentC = typename Policy::Operator::FragmentC;
418 
419   /// Warp-level Mma
420   using Operator = typename Policy::Operator;
421 
422   /// Obtain the arch tag from the warp-level operator
423   using ArchTag = typename Policy::Operator::ArchTag;
424 
425   /// Complex transform on B operand
426   static ComplexTransform const kTransformB = Operator::kTransformB;
427 
428   // statically assert kStages for MmaPipelined is two (Double-buffered pipeline)
429   static_assert(
430       (Base::kStages == 2),
431       "MmaPipelined requires kStages set to value 2");
432 
433  private:
434   using WarpFragmentA = typename Operator::FragmentA;
435 
436   /// fragment type of OperandA elementwise scaling matrix. (almost) empty
437   /// if operand A scaling is disabled.
438   using WarpFragmentAScale = typename WarpIteratorAScale::Fragment;
439 
440   using WarpFragmentB = typename Operator::FragmentB;
441 
442   /// applies scaling factor to operand A fragment if operand A scaling is
443   /// enabled. otherwise no-op.
444   using FragmentAScaler = FragmentElementwiseScaler<
445       WarpFragmentA,
446       WarpFragmentAScale,
447       ScaleOperandA>;
448 
449  protected:
450   // /// Iterator to write threadblock-scoped tile of A operand to shared memory
451   // SmemIteratorA smem_iterator_A_;
452 
453   /// Iterator to write threadblock-scoped tile of B operand to shared memory
454   SmemIteratorB smem_iterator_B_;
455 
456   /// Iterator to load a warp-scoped tile of A operand from intermediate
457   /// accumulator tile
458   WarpIteratorA warp_tile_iterator_A_;
459 
460   /// Iterator to load a warp-scoped tile of A_scale from intermediate
461   /// accumulator tile (only used if ScaleOperandA_ is true)
462   WarpIteratorAScale warp_tile_iterator_A_scale_;
463 
464  public:
465   /// constructor for MMA with operand A scaling enabled.
466   CUTLASS_DEVICE
467   MmaPipelinedFromSharedMemory(
468       typename Base::TensorRefA a, // Operand A in shared memory
469       typename Base::TensorRefA a_scale, // Operand A_scale in shared memory
470       typename Base::TensorRefB
471           b_staging, // staging memory for loading tiles of B
472       int thread_idx,
473       int warp_idx,
474       int lane_idx)
475       : Base(b_staging, thread_idx, warp_idx, lane_idx),
476         warp_tile_iterator_A_(a, lane_idx),
477         warp_tile_iterator_A_scale_(a_scale, lane_idx),
478         smem_iterator_B_(b_staging, thread_idx) {
479     // Compute warp location within threadblock tile by mapping the warp_id to
480     // three coordinates:
481     //   _m: the warp's position within the threadblock along the M dimension
482     //   _n: the warp's position within the threadblock along the N dimension
483     //   _k: the warp's position within the threadblock along the K dimension
484     int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
485     int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
486     int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
487     int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
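    // Numeric example (illustrative): with WarpCount = <4, 2, 1> (8 warps),
    // warp_idx = 5 gives warp_idx_mn = 5, warp_idx_k = 0, warp_idx_m = 1 and
    // warp_idx_n = 1.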
488 
489     // Add per-warp offsets in units of warp-level tiles
490     this->warp_tile_iterator_A_.add_tile_offset(
491         {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
492     this->warp_tile_iterator_A_scale_.add_tile_offset(
493         {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
494     this->warp_tile_iterator_B_.add_tile_offset(
495         {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
496   }
497 
498   /// Construct from tensor references
499   CUTLASS_DEVICE
500   MmaPipelinedFromSharedMemory(
501       typename Base::TensorRefA a, ///< Operand A in shared memory
502       typename Base::TensorRefB b_staging, ///< staging memory for loading B
503       int thread_idx, ///< ID within the threadblock
504       int warp_idx, ///< ID of warp
505       int lane_idx) ///< ID of each thread within a warp
506       : Base(b_staging, thread_idx, warp_idx, lane_idx),
507         warp_tile_iterator_A_(a, lane_idx),
508         smem_iterator_B_(b_staging, thread_idx) {
509     // Compute warp location within threadblock tile by mapping the warp_id to
510     // three coordinates:
511     //   _m: the warp's position within the threadblock along the M dimension
512     //   _n: the warp's position within the threadblock along the N dimension
513     //   _k: the warp's position within the threadblock along the K dimension
514 
515     int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
516     int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
517 
518     int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
519     int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
520 
521     // Add per-warp offsets in units of warp-level tiles
522     this->warp_tile_iterator_A_.add_tile_offset(
523         {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
524     this->warp_tile_iterator_B_.add_tile_offset(
525         {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
526   }
527 
528   // For API compatibility with MmaMultistageFromSharedMemory,
529   // but not supported as it worsens perf: older GPUs (< Sm80) don't
530   // support async transfers and would have to waste registers
531   CUTLASS_DEVICE
532   void set_prologue_done(bool value) {}
533   CUTLASS_DEVICE
534   static void prologue(
535       typename Base::SharedStorage& shared_storage,
536       IteratorB iterator_B1,
537       int thread_idx,
538       int problem_size_0_n) {}
539 
540   /// Perform a threadblock-scoped matrix multiply-accumulate
541   CUTLASS_DEVICE
542   void operator()(
543       int gemm_k_iterations, ///< number of iterations of the mainloop
544       FragmentC& accum, ///< destination accumulator tile
545       // IteratorA iterator_A,                             ///< iterator over A
546       // operand in global memory
547       IteratorB iterator_B, ///< iterator over B operand in global memory
548       FragmentC const& src_accum, ///< source accumulator tile
549       // TransformA transform_A = TransformA(),            ///< transformation
550       // applied to A fragment
551       TransformB transform_B =
552           TransformB()) { ///< transformation applied to B fragment
553 
554     //
555     // Prologue
556     //
557 
558     // Perform accumulation in the 'd' output operand
559     accum = src_accum;
560 
561     FragmentB tb_frag_B;
562 
563     tb_frag_B.clear();
564 
565     // The last kblock is loaded in the prolog
566     iterator_B.set_residual_tile(gemm_k_iterations == 1);
567     iterator_B.load(tb_frag_B);
568 
569     ++iterator_B;
570 
571     this->smem_iterator_B_.store(transform_B(tb_frag_B));
572 
573     ++this->smem_iterator_B_;
574 
575     __syncthreads();
576 
577     // remember that WarpFragmentAScale and WarpIteratorAScale are empty/no-op
578     // if scaling is disabled.
579 
580     // Pair of fragments used to overlap shared memory loads and math
581     // instructions
582     WarpFragmentA warp_frag_A[2];
583     WarpFragmentAScale warp_frag_A_scale[2];
584     WarpFragmentB warp_frag_B[2];
585     warp_frag_A[0].clear();
586     warp_frag_A_scale[0].clear();
587     warp_frag_B[0].clear();
588 
589     this->warp_tile_iterator_B_.set_kgroup_index(0);
590 
591     this->warp_tile_iterator_A_.load(warp_frag_A[0]);
592     this->warp_tile_iterator_A_scale_.load(warp_frag_A_scale[0]);
593     this->warp_tile_iterator_B_.load(warp_frag_B[0]);
594 
595     ++this->warp_tile_iterator_A_;
596     ++this->warp_tile_iterator_A_scale_;
597     ++this->warp_tile_iterator_B_;
598 
599     Operator warp_mma;
600 
601     int smem_write_stage_idx = 1;
602 
603     // Avoid reading out of bounds
604     iterator_B.set_residual_tile(gemm_k_iterations == 2);
605     iterator_B.clear_mask(gemm_k_iterations <= 1);
606 
607     // Issue loads during the first warp-level matrix multiply-add *AFTER*
608     // issuing shared memory loads (which have the tightest latency
609     // requirement).
610 
611     //
612     // Mainloop
613     //
614 
615     // Note: The main loop does not support Base::kWarpGemmIterations == 2.
616     CUTLASS_GEMM_LOOP
617     for (; gemm_k_iterations > 0; --gemm_k_iterations) {
618       //
619       // Loop over GEMM K dimension
620       //
621 
622       CUTLASS_PRAGMA_UNROLL
623       for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations;
624            ++warp_mma_k) {
625         // Load warp-level tiles from shared memory, wrapping to the k offset
626         // if this is the last group.
627         bool hasNext = true;
628 
629         if (warp_mma_k == Base::kWarpGemmIterations - 1) {
630           if (gemm_k_iterations > 1) {
631             // Write fragments to shared memory
632             this->smem_iterator_B_.store(transform_B(tb_frag_B));
633           }
634 
635           __syncthreads();
636 
637           ++this->smem_iterator_B_;
638 
639           // Add negative offsets to return iterators to the 'start' of the
640           // circular buffer in shared memory. SMEM: don't reset iterator A,
641           // as we are continuing our iteration at this point
642           if (smem_write_stage_idx == 1) {
643             this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
644           } else {
645             this->warp_tile_iterator_B_.add_tile_offset(
646                 {-Base::kStages * Policy::kPartitionsK *
647                      Base::kWarpGemmIterations,
648                  0});
649           }
650 
651           smem_write_stage_idx ^= 1;
652           hasNext = gemm_k_iterations > 1;
653         }
654 
655         // Only read the next if we need to
656         if (hasNext) {
657           this->warp_tile_iterator_B_.set_kgroup_index(
658               (warp_mma_k + 1) % Base::kWarpGemmIterations);
659 
660           this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
661           this->warp_tile_iterator_A_scale_.load(
662               warp_frag_A_scale[(warp_mma_k + 1) % 2]);
663           this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]);
664 
665           ++this->warp_tile_iterator_A_;
666           ++this->warp_tile_iterator_A_scale_;
667           ++this->warp_tile_iterator_B_;
668 
669           if (warp_mma_k == 0) {
670             iterator_B.load(tb_frag_B);
671 
672             ++iterator_B;
673 
674             // Avoid reading out of bounds if this was the last loop iteration
675             iterator_B.set_residual_tile(gemm_k_iterations == 3);
676             iterator_B.clear_mask(gemm_k_iterations <= 2);
677           }
678         }
679 
680         warp_mma(
681             accum,
682             FragmentAScaler::apply(
683                 warp_frag_A[warp_mma_k % 2], warp_frag_A_scale[warp_mma_k % 2]),
684             warp_frag_B[warp_mma_k % 2],
685             accum);
686       }
687     }
688   }
689 };
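
// Rough usage sketch (assumed, for illustration only; `MmaFromSmem`,
// `accum_storage` and `iterator_B` are placeholder names, and the real callers
// live in the attention kernels that include this header):
//
//   MmaFromSmem mma(accum_storage.accum_ref(),   // operand A, already in smem
//                   shared_storage.operand_B_ref(),
//                   thread_idx, warp_idx, lane_idx);
//   typename MmaFromSmem::FragmentC accum;
//   accum.clear();
//   int gemm_k_iterations = (problem_k + Shape::kK - 1) / Shape::kK;
//   mma(gemm_k_iterations, accum, iterator_B, accum);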
690 
691 ////////////////////////////////////////////////////////////////////////////////
692 // Taken from
693 // https://github.com/NVIDIA/cutlass/blob/master/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage_smem_accumulator.h
694 ////////////////////////////////////////////////////////////////////////////////
695 
696 /// Structure to compute the matrix product targeting CUDA cores and SIMT math
697 /// instructions.
698 template <
699     /// Size of the Gemm problem - concept: gemm::GemmShape<>
700     typename Shape1_,
701     /// Iterates over the intermediate accumulator tile in shared memory
702     typename WarpIteratorA1_,
703     /// whether or not to perform elementwise multiplication of A
704     //  by another matrix (A_scale) that is also kept in shared memory prior
705     //  to matmul A @ B
706     bool ScaleOperandA_,
707     /// Iterates over tiles of B operand in global memory
708     //  (concept: ReadableTileIterator | ForwardTileIterator |
709     //  MaskedTileIterator)
710     typename IteratorB1_,
711     /// Iterates over tiles of B operand in shared memory
712     /// (concept: WriteableTileIterator | RandomAccessTileIterator)
713     typename SmemIteratorB1_,
714     /// Cache operation for operand B
715     cutlass::arch::CacheOperation::Kind CacheOpB1,
716     /// Data type of accumulator matrix
717     typename ElementC_,
718     /// Data type of accumulator matrix
719     typename LayoutC_,
720     /// Policy describing tuning details (concept: MmaPolicy)
721     typename Policy1_,
722     /// Number of stages,
723     int Stages_,
724     int kMaxK_,
725     /// Used for partial specialization
726     typename Enable = bool>
727 class MmaMultistageFromSharedMemory : public MmaBaseFromSharedMemory<
728                                           Shape1_,
729                                           kMaxK_,
730                                           Policy1_,
731                                           Stages_,
732                                           typename WarpIteratorA1_::Layout> {
733  public:
734   ///< Base class
735   using Base = MmaBaseFromSharedMemory<
736       Shape1_,
737       kMaxK_,
738       Policy1_,
739       Stages_,
740       typename WarpIteratorA1_::Layout>;
741 
742   ///< Size of the Gemm problem - concept: gemm::GemmShape<>
743   using Shape1 = Shape1_;
744   ///< Iterates over tiles of B operand in global memory
745   using IteratorB1 = IteratorB1_;
746   using IteratorB = IteratorB1;
747   ///< Policy describing tuning details
748   using Policy1 = Policy1_;
749 
750   using SmemIteratorB1 = SmemIteratorB1_;
751   using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate
752                                           ///< accumulator tile in shared memory
753   static constexpr bool ScaleOperandA = ScaleOperandA_;
754 
755   ///< warp level iterator over A_scale matrix tile kept in shared memory.
756   ///< if elementwise A scaling is disabled then everything this does is no-op.
757   using WarpIteratorAScale = typename cutlass::platform::conditional<
758       ScaleOperandA,
759       WarpIteratorA1,
760       NoOpWarpIteratorScale<typename WarpIteratorA1::TensorRef>>::type;
761   ///< Data type of accumulator matrix
762   using ElementC = ElementC_;
763   ///< Layout of accumulator matrix
764   using LayoutC = LayoutC_;
765 
766   static cutlass::arch::CacheOperation::Kind const kCacheOpB1 = CacheOpB1;
767   static constexpr bool kSmemContainsEntireB = Base::kSmemContainsEntireB;
768 
769   //
770   // Dependent types
771   //
772 
773   /// Fragment of accumulator tile
774   using FragmentC1 = typename Policy1::Operator::FragmentC;
775   using FragmentC = FragmentC1;
776 
777   /// Warp-level Mma
778   using Operator1 = typename Policy1::Operator;
779 
780   /// Minimum architecture is Sm80 to support cp.async
781   using ArchTag = arch::Sm80;
782 
783   /// Complex transform on B operand
784   static ComplexTransform const kTransformB1 = Operator1::kTransformB;
785 
786   /// Internal structure exposed for introspection.
787   struct Detail {
788     static_assert(
789         Base::kWarpGemmIterations1 > 1,
790         "The pipelined structure requires at least two warp-level "
791         "GEMM operations.");
792 
793     /// Number of cp.async instructions to load one stage of operand B
794     static int const TBLoadIterationsB1 =
795         IteratorB1::ThreadMap::Iterations::kCount;
796 
797     /// Number of cp.async instructions to load one group of operand B
798     static int const kAccessesPerGroupB1 =
799         (TBLoadIterationsB1 + Base::kWarpGemmIterations1 - 1) /
800         Base::kWarpGemmIterations1;
801   };
802 
803   static constexpr int kNumStagesConcurrentLoad =
804       kSmemContainsEntireB ? Base::kStages : Base::kStages - 1;
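  // Worked example (illustrative numbers): if IteratorB1 needs
  // TBLoadIterationsB1 = 8 cp.async accesses per stage and
  // kWarpGemmIterations1 = 4, then kAccessesPerGroupB1 = (8 + 4 - 1) / 4 = 2
  // copies are interleaved with each warp-level MMA. With kStages = 3 and
  // operand B not fully resident in smem, kNumStagesConcurrentLoad = 2 stages
  // are kept in flight while the third is being consumed.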
805 
806  private:
807   using WarpLoadedFragmentA1 = typename Operator1::FragmentA;
808   /// fragment of OperandA scale matrix. if operand A scaling is disabled this
809   /// is (almost) empty.
810   using WarpLoadedFragmentA1Scale = typename WarpIteratorAScale::Fragment;
811   using WarpLoadedFragmentB1 = typename Operator1::FragmentB;
812   using WarpTransformedFragmentA1 = typename Operator1::TransformedFragmentA;
813   using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB;
814 
815   /// applies elementwise scaling to fragment of A. if operand A scaling is
816   /// disabled this is a no-op.
817   using FragmentAScaler = FragmentElementwiseScaler<
818       WarpLoadedFragmentA1,
819       WarpLoadedFragmentA1Scale,
820       ScaleOperandA>;
821 
822  private:
823   //
824   // Data members
825   //
826 
827   /// Iterator to load a warp-scoped tile of A1 operand from intermediate
828   /// accumulator tile
829   WarpIteratorA1 warp_tile_iterator_A1_;
830 
831   /// Iterator to load a warp-scoped tile of A1_scale operand from shared memory
832   /// if operand A scaling is disabled everything this does is a no-op.
833   WarpIteratorAScale warp_tile_iterator_A1_scale_;
834 
835   /// Iterator to write threadblock-scoped tile of B operand to shared memory
836   SmemIteratorB1 smem_iterator_B1_;
837 
838   bool prologue_done_;
839 
840  public:
841   /// constructor for MMA with operand A scaling enabled.
842   CUTLASS_DEVICE
843   MmaMultistageFromSharedMemory(
844       typename Base::TensorRefA a,
845       typename Base::TensorRefA a_scale,
846       typename Base::TensorRefB b_tile,
847       int thread_idx,
848       int warp_idx,
849       int lane_idx)
850       : Base(b_tile, thread_idx, warp_idx, lane_idx),
851         warp_tile_iterator_A1_(a, lane_idx),
852         warp_tile_iterator_A1_scale_(a_scale, lane_idx),
853         smem_iterator_B1_(b_tile, thread_idx),
854         prologue_done_(false) {
855     // Compute warp location within threadblock tile by mapping the warp_id to
856     // three coordinates:
857     //   _m: the warp's position within the threadblock along the M dimension
858     //   _n: the warp's position within the threadblock along the N dimension
859     //   _k: the warp's position within the threadblock along the K dimension
860     int warp_idx_mn_1 =
861         warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN);
862     int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN);
863     int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM;
864     int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM;
865 
866     // Add per-warp offsets in units of warp-level tiles
867     warp_tile_iterator_A1_.add_tile_offset(
868         {warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1});
869     warp_tile_iterator_A1_scale_.add_tile_offset(
870         {warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1});
871     this->warp_tile_iterator_B_.add_tile_offset(
872         {Base::kWarpGemmIterations1 * warp_idx_k_1, warp_idx_n_1});
873   }
874 
875   /// Construct from tensor references
876   CUTLASS_DEVICE
877   MmaMultistageFromSharedMemory(
878       typename Base::TensorRefA a,
879       typename Base::TensorRefB b_tile,
880       ///< ID within the threadblock
881       int thread_idx,
882       ///< ID of warp
883       int warp_idx,
884       ///< ID of each thread within a warp
885       int lane_idx)
886       : Base(b_tile, thread_idx, warp_idx, lane_idx),
887         warp_tile_iterator_A1_(a, lane_idx),
888         smem_iterator_B1_(b_tile, thread_idx),
889         prologue_done_(false) {
890     // Compute warp location within threadblock tile by mapping the warp_id to
891     // three coordinates:
892     //   _m: the warp's position within the threadblock along the M dimension
893     //   _n: the warp's position within the threadblock along the N dimension
894     //   _k: the warp's position within the threadblock along the K dimension
895 
896     int warp_idx_mn_1 =
897         warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN);
898     int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN);
899 
900     int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM;
901     int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM;
902 
903     // Add per-warp offsets in units of warp-level tiles
904     warp_tile_iterator_A1_.add_tile_offset(
905         {warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1});
906     this->warp_tile_iterator_B_.add_tile_offset(
907         {Base::kWarpGemmIterations1 * warp_idx_k_1, warp_idx_n_1});
908   }
909 
910   CUTLASS_DEVICE
911   void set_prologue_done(bool value) {
912     prologue_done_ = value;
913   }
914 
915   CUTLASS_DEVICE
916   static void prologue(
917       typename Base::SharedStorage& shared_storage,
918       IteratorB iterator_B1,
919       int thread_idx,
920       int problem_size_0_n) {
921     SmemIteratorB1 smem_iterator_B1(shared_storage.operand_B_ref(), thread_idx);
922     _prologue(
923         iterator_B1,
924         (problem_size_0_n + Base::Shape::kK - 1) / Base::Shape::kK,
925         smem_iterator_B1);
926   }
927 
928   CUTLASS_DEVICE
929   void copy_tiles_and_advance_1(
930       IteratorB1& iterator_B1,
931       int group_start_B1 = 0) {
932     iterator_B1.set_iteration_index(
933         group_start_B1 * IteratorB1::kAccessesPerVector);
934     this->smem_iterator_B1_.set_iteration_index(group_start_B1);
935 
936     // Load for operand B
937     CUTLASS_PRAGMA_UNROLL
938     for (int j = 0; j < Detail::kAccessesPerGroupB1; ++j) {
939       if (group_start_B1 + j < Detail::TBLoadIterationsB1) {
940         typename IteratorB1::AccessType* dst_ptr =
941             reinterpret_cast<typename IteratorB1::AccessType*>(
942                 this->smem_iterator_B1_.get());
943 
944         int const kSrcBytes = sizeof_bits<typename IteratorB1::Element>::value *
945             IteratorB1::ThreadMap::kElementsPerAccess /
946             IteratorB1::kAccessesPerVector / 8;
947 
948         CUTLASS_PRAGMA_UNROLL
949         for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) {
950           auto gmem_ptr = iterator_B1.get();
951 
952           cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB1>(
953               dst_ptr + v, gmem_ptr, iterator_B1.valid());
954 
955           ++iterator_B1;
956         }
957         ++this->smem_iterator_B1_;
958       }
959     }
960   }
961 
962   CUTLASS_DEVICE
963   static void _prologue(
964       IteratorB& iterator_B1,
965       int32_t gemm_k_iterations_1,
966       SmemIteratorB1& smem_iterator_B1_) {
967     // Issue several complete stages
968     CUTLASS_PRAGMA_UNROLL
969     for (int stage = 0; stage < kNumStagesConcurrentLoad;
970          ++stage, --gemm_k_iterations_1) {
971       iterator_B1.set_residual_tile(gemm_k_iterations_1 == 1);
972       iterator_B1.clear_mask(gemm_k_iterations_1 == 0);
973 
974       iterator_B1.set_iteration_index(0);
975       smem_iterator_B1_.set_iteration_index(0);
976 
977       // Load for operand B
978       CUTLASS_PRAGMA_UNROLL
979       for (int j = 0; j < Detail::TBLoadIterationsB1; ++j) {
980         typename IteratorB1::AccessType* dst_ptr =
981             reinterpret_cast<typename IteratorB1::AccessType*>(
982                 smem_iterator_B1_.get());
983 
984         CUTLASS_PRAGMA_UNROLL
985         for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) {
986           int const kSrcBytes =
987               sizeof_bits<typename IteratorB1::Element>::value *
988               IteratorB1::ThreadMap::kElementsPerAccess /
989               IteratorB1::kAccessesPerVector / 8;
990 
991           cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB1>(
992               dst_ptr + v, iterator_B1.get(), iterator_B1.valid());
993 
994           ++iterator_B1;
995         }
996 
997         ++smem_iterator_B1_;
998       }
999 
1000       // Move to the next stage
1001       iterator_B1.add_tile_offset({1, 0});
1002 
1003       smem_iterator_B1_.add_tile_offset({1, 0});
1004 
1005       // Defines the boundary of a stage of cp.async.
1006       cutlass::arch::cp_async_fence();
1007     }
1008     iterator_B1.set_residual_tile(gemm_k_iterations_1 == 1);
1009     iterator_B1.clear_mask(gemm_k_iterations_1 == 0);
1010   }
1011 
1012   /// Perform a threadblock-scoped matrix multiply-accumulate
1013   CUTLASS_DEVICE
1014   void operator()(
1015       ///< problem size of GEMM
1016       int gemm_k_iterations_1_,
1017       ///< destination accumulator tile
1018       FragmentC1& accum,
1019       ///< iterator over B1 operand in global memory
1020       IteratorB1 iterator_B1,
1021       ///< initial value of accumulator
1022       FragmentC1 const& src_accum) {
1023     // 2nd Gemm
1024 
1025     //
1026     // Prologue
1027     //
1028     // Perform accumulation in the 'd' output operand
1029     accum = src_accum;
1030 
1031     if (!prologue_done_) {
1032       _prologue(iterator_B1, gemm_k_iterations_1_, smem_iterator_B1_);
1033     } else if (!kSmemContainsEntireB) {
1034       // Restore the iterators' increments
1035 
1036       int gemm_k_iterations_1 = gemm_k_iterations_1_;
1037       // Issue several complete stages
1038       CUTLASS_PRAGMA_UNROLL
1039       for (int stage = 0; stage < kNumStagesConcurrentLoad;
1040            ++stage, --gemm_k_iterations_1) {
1041         iterator_B1.set_iteration_index(0);
1042         this->smem_iterator_B1_.set_iteration_index(0);
1043 
1044         // Load for operand B
1045         CUTLASS_PRAGMA_UNROLL
1046         for (int j = 0; j < Detail::TBLoadIterationsB1; ++j) {
1047           CUTLASS_PRAGMA_UNROLL
1048           for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) {
1049             ++iterator_B1;
1050           }
1051           ++this->smem_iterator_B1_;
1052         }
1053         iterator_B1.add_tile_offset({1, 0});
1054         this->smem_iterator_B1_.add_tile_offset({1, 0});
1055       }
1056       iterator_B1.set_residual_tile(gemm_k_iterations_1 <= 1);
1057       iterator_B1.clear_mask(gemm_k_iterations_1 <= 0);
1058     }
1059 
1060     // DEPBAR+SYNC
1061     cutlass::arch::cp_async_wait<kNumStagesConcurrentLoad - 1>();
1062     __syncthreads();
1063 
1064     // remember that WarpFragmentAScale and WarpIteratorAScale are no-op/empty
1065     // if scaling is disabled.
1066 
1067     // Pair of fragments used to overlap shared memory loads and math
1068     // instructions
1069     WarpLoadedFragmentA1 warp_loaded_frag_A1[2];
1070     WarpLoadedFragmentA1Scale warp_loaded_frag_A1_scale[2];
1071     WarpLoadedFragmentB1 warp_loaded_frag_B1[2];
1072     WarpTransformedFragmentA1 warp_transformed_frag_A1[2];
1073     WarpTransformedFragmentB1 warp_transformed_frag_B1[2];
1074 
1075     Operator1 warp_mma1;
1076 
1077     warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0]);
1078     ++warp_tile_iterator_A1_;
1079 
1080     warp_tile_iterator_A1_scale_.load(warp_loaded_frag_A1_scale[0]);
1081     ++warp_tile_iterator_A1_scale_;
1082 
1083     this->warp_tile_iterator_B_.set_kgroup_index(0);
1084     this->warp_tile_iterator_B_.load(warp_loaded_frag_B1[0]);
1085     ++this->warp_tile_iterator_B_;
1086 
1087     int smem_write_stage_idx = Base::kStages - 1;
1088     int smem_read_stage_idx = 0;
1089 
1090     warp_mma1.transform(
1091         warp_transformed_frag_A1[0],
1092         warp_transformed_frag_B1[0],
1093         FragmentAScaler::apply(
1094             warp_loaded_frag_A1[0], warp_loaded_frag_A1_scale[0]),
1095         warp_loaded_frag_B1[0]);
1096 
1097     // tf32x3 kernels use staging accumulation. warp_mma uses a temporary
1098     // accumulator and this temporary accumulator is added to the final
1099     // accumulator once in every mainloop iteration.
1100     plus<FragmentC1> plus_accum;
1101 
1102     FragmentC1 tmp_accum;
1103 
1104     if (platform::is_same<
1105             typename Operator1::MathOperator,
1106             arch::OpMultiplyAddFastF32>::value ||
1107         platform::is_same<
1108             typename Operator1::MathOperator,
1109             arch::OpMultiplyAddComplexFastF32>::value) {
1110       tmp_accum.clear();
1111     }
1112 
1113     //
1114     // Mainloop
1115     //
1116 
1117     CUTLASS_PRAGMA_UNROLL
1118     for (int gemm_k_iterations_1 = gemm_k_iterations_1_ - (Base::kStages - 1);
1119          gemm_k_iterations_1 > (-Base::kStages + 1);
1120          gemm_k_iterations_1--) {
1121       //
1122       // Loop over GEMM K dimension
1123       //
1124 
1125       // Computes a warp-level GEMM on data held in shared memory
1126       // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
1127       CUTLASS_PRAGMA_UNROLL
1128       for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1;
1129            ++warp_mma_k) {
1130         // Load warp-level tile from accumulator fragment (A)
1131         // or shared memory (operand B)
1132         this->warp_tile_iterator_B_.set_kgroup_index(
1133             (warp_mma_k + 1) % Base::kWarpGemmIterations1);
1134         // skip warp tile loading for the last kgroup (we are out of the buffer)
1135         if (gemm_k_iterations_1 > (-Base::kStages + 2) ||
1136             warp_mma_k < Base::kWarpGemmIterations1 - 1) {
1137           warp_tile_iterator_A1_.load(
1138               warp_loaded_frag_A1[(warp_mma_k + 1) % 2]);
1139           warp_tile_iterator_A1_scale_.load(
1140               warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2]);
1141           this->warp_tile_iterator_B_.load(
1142               warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);
1143         }
1144         ++warp_tile_iterator_A1_;
1145         ++warp_tile_iterator_A1_scale_;
1146         ++this->warp_tile_iterator_B_;
1147 
1148         if (warp_mma_k > 0)
1149           warp_mma1.transform(
1150               warp_transformed_frag_A1[warp_mma_k % 2],
1151               warp_transformed_frag_B1[warp_mma_k % 2],
1152               FragmentAScaler::apply(
1153                   warp_loaded_frag_A1[warp_mma_k % 2],
1154                   warp_loaded_frag_A1_scale[warp_mma_k % 2]),
1155               warp_loaded_frag_B1[warp_mma_k % 2]);
1156 
1157         if (platform::is_same<
1158                 typename Operator1::MathOperator,
1159                 arch::OpMultiplyAddFastF32>::value ||
1160             platform::is_same<
1161                 typename Operator1::MathOperator,
1162                 arch::OpMultiplyAddComplexFastF32>::value) {
1163           warp_mma1(
1164               tmp_accum,
1165               warp_transformed_frag_A1[warp_mma_k % 2],
1166               warp_transformed_frag_B1[warp_mma_k % 2],
1167               tmp_accum);
1168 
1169           if (warp_mma_k == 0) {
1170             accum = plus_accum(accum, tmp_accum);
1171             tmp_accum.clear();
1172           }
1173         } else {
1174           warp_mma1(
1175               accum,
1176               warp_transformed_frag_A1[warp_mma_k % 2],
1177               warp_transformed_frag_B1[warp_mma_k % 2],
1178               accum);
1179         }
1180 
1181         // Issue global->shared copies for this stage
1182         if (warp_mma_k < Base::kWarpGemmIterations1 - 1) {
1183           int group_start_iteration_B1;
1184 
1185           group_start_iteration_B1 = warp_mma_k * Detail::kAccessesPerGroupB1;
1186 
1187           if (!kSmemContainsEntireB) {
1188             copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1);
1189           }
1190         }
1191 
1192         if (warp_mma_k + 2 == Base::kWarpGemmIterations1) {
1193           int group_start_iteration_B1;
1194           group_start_iteration_B1 =
1195               (warp_mma_k + 1) * Detail::kAccessesPerGroupB1;
1196 
1197           if (!kSmemContainsEntireB) {
1198             copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1);
1199           }
1200 
1201           // Inserts a memory fence between stages of cp.async instructions.
1202           cutlass::arch::cp_async_fence();
1203 
1204           // Waits until kStages-2 stages have committed.
1205           arch::cp_async_wait<kNumStagesConcurrentLoad - 1>();
1206           __syncthreads();
1207 
1208           // Move to the next stage
1209           iterator_B1.add_tile_offset({1, 0});
1210 
1211           this->smem_iterator_B1_.add_tile_offset({1, 0});
1212 
1213           // Add negative offsets to return iterators to the 'start' of the
1214           // circular buffer in shared memory
1215           if (!kSmemContainsEntireB) {
1216             if (smem_write_stage_idx == (Base::kStages - 1)) {
1217               this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0});
1218               smem_write_stage_idx = 0;
1219             } else {
1220               ++smem_write_stage_idx;
1221             }
1222 
1223             if (smem_read_stage_idx == (Base::kStages - 1)) {
1224               this->warp_tile_iterator_B_.add_tile_offset(
1225                   {-Base::kStages * Policy1::kPartitionsK *
1226                        Base::kWarpGemmIterations1,
1227                    0});
1228               smem_read_stage_idx = 0;
1229             } else {
1230               ++smem_read_stage_idx;
1231             }
1232           }
1233 
1234           iterator_B1.set_residual_tile(gemm_k_iterations_1 == 2);
1235           iterator_B1.clear_mask(gemm_k_iterations_1 == 1);
1236         }
1237 
1238         // Do any conversions feeding the first stage at the end of the loop so
1239         // we can start right away on mma instructions
1240         if (warp_mma_k + 1 == Base::kWarpGemmIterations1)
1241           warp_mma1.transform(
1242               warp_transformed_frag_A1[(warp_mma_k + 1) % 2],
1243               warp_transformed_frag_B1[(warp_mma_k + 1) % 2],
1244               FragmentAScaler::apply(
1245                   warp_loaded_frag_A1[(warp_mma_k + 1) % 2],
1246                   warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2]),
1247               warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);
1248       }
1249     }
1250 
1251     if (platform::is_same<
1252             typename Operator1::MathOperator,
1253             arch::OpMultiplyAddFastF32>::value ||
1254         platform::is_same<
1255             typename Operator1::MathOperator,
1256             arch::OpMultiplyAddComplexFastF32>::value) {
1257       accum = plus_accum(accum, tmp_accum);
1258     }
1259   }
1260 };
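
// Rough call sequence (assumed, placeholder names; the real call sites are in
// the attention kernels that consume this header):
//
//   // Optionally overlap the first B loads with preceding work:
//   Mma::prologue(shared_storage, iterator_B1, thread_idx, problem_size_0_n);
//   ...
//   Mma mma(accum_ref, shared_storage.operand_B_ref(),
//           thread_idx, warp_idx, lane_idx);
//   mma.set_prologue_done(true);   // tell operator() the smem is pre-filled
//   mma(gemm_k_iterations_1, accum, iterator_B1, accum);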
1261 
1262 // Converts a "regular" Mma into its counterpart that reads operand A from shared memory
1263 template <
1264     typename Mma_,
1265     int kMaxK,
1266     typename WarpIteratorA_,
1267     /// whether or not to apply elementwise multiplication of operand A by
1268     /// another matrix in shared memory before usage in A @ B
1269     bool kScaleOperandA,
1270     bool kTransposeA = false>
1271 struct DefaultMmaFromSharedMemory;
1272 
1273 // Mma pipelined
1274 template <
1275     /// Size of the Gemm problem - concept: gemm::GemmShape<>
1276     typename Shape_,
1277     /// Iterates over tiles of A operand in global memory
1278     //  (concept: ReadableTileIterator | ForwardTileIterator |
1279     //  MaskedTileIterator)
1280     typename IteratorA_,
1281     /// Iterates over tiles of A operand in shared memory
1282     /// (concept: WriteableTileIterator | RandomAccessTileIterator)
1283     typename SmemIteratorA_,
1284     typename WarpIteratorA_,
1285     /// Iterates over tiles of B operand in global memory
1286     //  (concept: ReadableTileIterator | ForwardTileIterator |
1287     //  MaskedTileIterator)
1288     typename IteratorB_,
1289     /// Iterates over tiles of B operand in shared memory
1290     /// (concept: WriteableTileIterator | RandomAccessTileIterator)
1291     typename SmemIteratorB_,
1292     /// Data type of accumulator matrix
1293     typename ElementC_,
1294     /// Data type of accumulator matrix
1295     typename LayoutC_,
1296     /// Policy describing tuning details (concept: MmaPolicy)
1297     typename Policy_,
1298     /// Transformation applied to A operand
1299     typename TransformA_,
1300     /// Transformation applied to B operand
1301     typename TransformB_,
1302     // Max MMA problem size K
1303     int kMaxK,
1304     /// whether or not to apply elementwise multiplication of operand A by
1305     /// another matrix in shared memory before usage in A @ B
1306     bool kScaleOperandA,
1307     bool kTransposeA>
1308 struct DefaultMmaFromSharedMemory<
1309     MmaPipelined<
1310         Shape_,
1311         IteratorA_,
1312         SmemIteratorA_,
1313         IteratorB_,
1314         SmemIteratorB_,
1315         ElementC_,
1316         LayoutC_,
1317         Policy_,
1318         TransformA_,
1319         TransformB_>,
1320     kMaxK,
1321     WarpIteratorA_,
1322     kScaleOperandA,
1323     kTransposeA> {
1324   using RegularMma = MmaPipelined<
1325       Shape_,
1326       IteratorA_,
1327       SmemIteratorA_,
1328       IteratorB_,
1329       SmemIteratorB_,
1330       ElementC_,
1331       LayoutC_,
1332       Policy_,
1333       TransformA_,
1334       TransformB_>;
1335 
1336   using WarpShape = typename Policy_::Operator::Shape;
1337   using InstructionShape = typename Policy_::Operator::InstructionShape;
1338   using ArchMmaOperator = typename Policy_::Operator;
1339 
1340   static constexpr bool kIsTransposedA = false;
1341   using WarpIteratorA = WarpIteratorA_;
1342   using IteratorB =
1343       typename cutlass::transform::threadblock::MakeIteratorResidualLast<
1344           IteratorB_>::Iterator;
1345 
1346   using Mma = typename cutlass::gemm::threadblock::MmaPipelinedFromSharedMemory<
1347       Shape_,
1348       WarpIteratorA,
1349       kScaleOperandA,
1350       kMaxK,
1351       IteratorB,
1352       SmemIteratorB_,
1353       ElementC_,
1354       LayoutC_,
1355       Policy_>;
1356 };
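
// Typical conversion sketch (assumed names; `DefaultMma` stands for a regular
// cutlass::gemm::threadblock::DefaultMma instantiation chosen elsewhere):
//
//   using MmaFromSmem = typename cutlass::gemm::threadblock::
//       DefaultMmaFromSharedMemory<
//           typename DefaultMma::ThreadblockMma,  // MmaPipelined or MmaMultistage
//           kMaxK,
//           WarpIteratorA,                        // warp iterator over the smem accumulator
//           /*kScaleOperandA=*/false>::Mma;
//
// The partial specializations below match on whether the original Mma was
// pipelined (2 stages) or multistage (cp.async, Sm80+) and pick the
// corresponding *FromSharedMemory class.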
1357 
1358 template <
1359     /// Size of the Gemm problem - concept: gemm::GemmShape<>
1360     typename Shape_,
1361     /// Iterates over tiles of A operand in global memory
1362     //  (concept: ReadableTileIterator | ForwardTileIterator |
1363     //  MaskedTileIterator)
1364     typename IteratorA_,
1365     /// Iterates over tiles of A operand in shared memory
1366     /// (concept: WriteableTileIterator | RandomAccessTileIterator)
1367     typename SmemIteratorA_,
1368     typename WarpIteratorA_,
1369     /// Cache operation for operand A
1370     cutlass::arch::CacheOperation::Kind CacheOpA,
1371     /// Iterates over tiles of B operand in global memory
1372     //  (concept: ReadableTileIterator | ForwardTileIterator |
1373     //  MaskedTileIterator)
1374     typename IteratorB_,
1375     /// Iterates over tiles of B operand in shared memory
1376     /// (concept: WriteableTileIterator | RandomAccessTileIterator)
1377     typename SmemIteratorB_,
1378     /// Cache operation for operand B
1379     cutlass::arch::CacheOperation::Kind CacheOpB,
1380     /// Data type of accumulator matrix
1381     typename ElementC_,
1382     /// Layout of accumulator matrix
1383     typename LayoutC_,
1384     /// Policy describing tuning details (concept: MmaPolicy)
1385     typename Policy_,
1386     /// Number of stages,
1387     int Stages,
1388     /// Use zfill or predicate for out-of-bound cp.async
1389     SharedMemoryClearOption SharedMemoryClear,
1390     int kMaxK,
1391     /// Whether to apply an elementwise multiplication of operand A by
1392     /// another matrix in shared memory before it is used in A @ B
1393     bool kScaleOperandA,
1394     bool kTransposeA>
1395 struct DefaultMmaFromSharedMemory<
1396     MmaMultistage<
1397         Shape_,
1398         IteratorA_,
1399         SmemIteratorA_,
1400         CacheOpA,
1401         IteratorB_,
1402         SmemIteratorB_,
1403         CacheOpB,
1404         ElementC_,
1405         LayoutC_,
1406         Policy_,
1407         Stages,
1408         SharedMemoryClear>,
1409     kMaxK,
1410     WarpIteratorA_,
1411     kScaleOperandA,
1412     kTransposeA> {
1413   using RegularMma = MmaMultistage<
1414       Shape_,
1415       IteratorA_,
1416       SmemIteratorA_,
1417       CacheOpA,
1418       IteratorB_,
1419       SmemIteratorB_,
1420       CacheOpB,
1421       ElementC_,
1422       LayoutC_,
1423       Policy_,
1424       Stages,
1425       SharedMemoryClear>;
1426 
1427   using WarpShape = typename Policy_::Operator::Shape;
1428   using InstructionShape = typename Policy_::Operator::InstructionShape;
1429   using WarpIteratorTranspose = TransposeWarpIterator<WarpIteratorA_>;
1430   static constexpr bool kIsTransposedA =
1431       WarpIteratorTranspose::kSupportsTranspose && kTransposeA;
1432   using WarpIteratorA = typename platform::conditional<
1433       kIsTransposedA,
1434       typename WarpIteratorTranspose::Iterator,
1435       WarpIteratorA_>::type;
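  // Note: if the provided warp iterator cannot be transposed
  // (kSupportsTranspose is false), a kTransposeA request is effectively
  // ignored and WarpIteratorA_ is used as-is; kIsTransposedA records which
  // path was selected.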
1436 
1437   // Reduce the number of stages if we don't need that many
1438   static int constexpr kStagesMax =
1439       (kMaxK + int(Shape_::kK) - 1) / int(Shape_::kK);
1440   static int constexpr kStages = cutlass::const_min(Stages, kStagesMax);
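  // Worked example with hypothetical sizes: with Shape_::kK == 32 and
  // kMaxK == 64 there are only kStagesMax == (64 + 32 - 1) / 32 == 2 K-tiles
  // to pipeline, so a kernel configured with Stages == 4 is instantiated
  // here with kStages == 2.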
1441 
1442   using IteratorB =
1443       typename cutlass::transform::threadblock::MakeIteratorResidualLast<
1444           IteratorB_>::Iterator;
1445   using Mma =
1446       typename cutlass::gemm::threadblock::MmaMultistageFromSharedMemory<
1447           Shape_,
1448           WarpIteratorA,
1449           kScaleOperandA,
1450           IteratorB,
1451           SmemIteratorB_,
1452           RegularMma::kCacheOpB,
1453           ElementC_,
1454           LayoutC_,
1455           Policy_,
1456           kStages,
1457           kMaxK>;
1458 };
1459 
1460 /////////////////////////////////////////////////////////////////////////////////////////////////
1461 
1462 template <
1463     typename IteratorC,
1464     typename Operator,
1465     typename scalar_t,
1466     typename WarpShape_,
1467     typename ThreadblockShape_>
1468 struct B2bGemm;
1469 
1470 // Tensor Core specialization for Sm75 and above (Turing, Ampere, ...)
1471 template < /// Size of the matrix to load (concept: MatrixShape)
1472     typename Shape_,
1473     /// Element type
1474     typename Element_,
1475     /// Layout of operand in memory
1476     typename Layout_,
1477     /// Shape of one matrix product operation (concept: MatrixShape)
1478     typename InstructionShape_,
1479     /// Interval between adjacent *MMA instructions (in units of MMA
1480     /// instructions, concept: MatrixShape)
1481     typename OpDelta_,
1482     typename Operator,
1483     typename scalar_t,
1484     typename WarpShape_,
1485     typename ThreadblockShape_>
1486 struct B2bGemm<
1487     cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator<
1488         Shape_,
1489         Element_,
1490         Layout_,
1491         InstructionShape_,
1492         OpDelta_>,
1493     Operator,
1494     scalar_t,
1495     WarpShape_,
1496     ThreadblockShape_> {
1497   using IteratorC =
1498       typename cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator<
1499           Shape_,
1500           Element_,
1501           Layout_,
1502           InstructionShape_,
1503           OpDelta_>;
1504   using FragmentC = typename IteratorC::Fragment;
1505   using InstructionShape = InstructionShape_;
1506   using WarpShape = WarpShape_;
1507   using ThreadblockShape = ThreadblockShape_;
1508   using accum_t = Element_;
1509   using lse_scalar_t = float;
1510 
1511   using SmemAccumulatorLayout = cutlass::layout::RowMajor;
1512 
1513   // Iterator to load accumulators (results of matmul in registers)
1514   using FragmentIteratorAccumulator =
1515       cutlass::epilogue::warp::FragmentIteratorTensorOp<
1516           WarpShape,
1517           InstructionShape,
1518           accum_t,
1519           typename Operator::Policy::Operator::FragmentC,
1520           cutlass::layout::RowMajor>;
1521 
1522   // Iterator to store to shared-memory
1523   using SmemIteratorD0 = typename cutlass::epilogue::warp::TileIteratorTensorOp<
1524       WarpShape,
1525       InstructionShape,
1526       scalar_t, // accum_t,
1527       SmemAccumulatorLayout>;
1528   using AccumulatorSharedStorage =
1529       cutlass::gemm::threadblock::AccumulatorSharedStorage<
1530           ThreadblockShape,
1531           typename SmemIteratorD0::Element,
1532           typename SmemIteratorD0::TensorLayout,
1533           typename SmemIteratorD0::Padding>;
1534   // We need to provide an operation for the epilogue. Let's create an
1535   // operation that does nothing (ScaleType::Nothing), just converts
1536   // from accum_t (float) -> scalar_t (can be half)
1537   using OutputOpNoOp = cutlass::epilogue::thread::LinearCombination<
1538       typename SmemIteratorD0::Element, // ElementOutput
1539       FragmentIteratorAccumulator::Fragment::kElements,
1540       accum_t, // ElementAccumulator
1541       typename SmemIteratorD0::Element, // ElementCompute
1542       cutlass::epilogue::thread::ScaleType::Nothing>;
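  // With ScaleType::Nothing the usual `d = alpha * accum + beta * source`
  // linear combination degenerates to a plain per-element type conversion,
  // roughly `d[i] = ElementOutput(accum[i])`; the actual store pattern is
  // handled by EpilogueSmemAccumulator below.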
1543   using Epilogue = cutlass::epilogue::threadblock::EpilogueSmemAccumulator<
1544       SmemIteratorD0,
1545       FragmentIteratorAccumulator,
1546       SmemIteratorD0, // ScaleBiasIterator - not used
1547       OutputOpNoOp>;
1548 
1549   // Epilogue 2: with LSE (for the backward pass)
1550   static int const kElementsPerAccess = 2; // TODO: Why 2?
1551   using IteratorAccumulatorLSE =
1552       cutlass::transform::threadblock::VectorIterator<
1553           cutlass::transform::threadblock::PredicatedVectorAccessIterator<
1554               // Shape
1555               cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kN>,
1556               // WarpShape
1557               cutlass::MatrixShape<WarpShape::kM, WarpShape::kN>,
1558               lse_scalar_t,
1559               cutlass::layout::RowMajor,
1560               kElementsPerAccess>>;
1561   using EpilogueOpApplyLSE = cutlass::epilogue::thread::ApplyLogSumExp<
1562       scalar_t, // ElementOutput_
1563       lse_scalar_t, // ElementLSE_
1564       accum_t, // ElementAccumulator_
1565       accum_t, // ElementCompute_
1566       128 / cutlass::sizeof_bits<scalar_t>::value
1567       // FragmentIteratorAccumulator::Fragment::kElements
1568       // InstructionShape::kM * InstructionShape::kN / 32
1569       >;
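  // Conceptually, ApplyLogSumExp produces `out = exp(accum - lse)` element
  // by element, where `lse` is the precomputed log-sum-exp of the
  // corresponding attention row, turning raw logits into softmax
  // probabilities; the non-tensor-op specializations further below spell
  // the same operation out with plain expf() calls.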
1570   using EpilogueWithLSE =
1571       cutlass::epilogue::threadblock::EpilogueSmemAccumulator<
1572           SmemIteratorD0,
1573           FragmentIteratorAccumulator,
1574           IteratorAccumulatorLSE,
1575           EpilogueOpApplyLSE>;
1576 
1577   static void CUTLASS_DEVICE accumToSmem(
1578       AccumulatorSharedStorage& shared_storage,
1579       FragmentC const& accum,
1580       int lane_id,
1581       cutlass::MatrixCoord const& tile_coords) {
1582     SmemIteratorD0 smem_iterator_attn(shared_storage.accum_ref(), lane_id);
1583     smem_iterator_attn.add_tile_offset(
1584         tile_coords *
1585         cutlass::MatrixCoord{
1586             SmemIteratorD0::TileIterations::kRow,
1587             SmemIteratorD0::TileIterations::kColumn});
1588     Epilogue epilogue;
1589     epilogue(OutputOpNoOp({}), smem_iterator_attn, accum);
1590   }
1591 
1592   static void CUTLASS_DEVICE accumApplyLSEToSmem(
1593       AccumulatorSharedStorage& shared_storage,
1594       FragmentC& accum,
1595       lse_scalar_t const* lse,
1596       int32_t lse_extents,
1597       int thread_id,
1598       int warp_id,
1599       int lane_id,
1600       cutlass::MatrixCoord const& tile_coords) {
1601     constexpr int32_t kAlignLSE = 32;
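    // The extent passed to the LSE iterator below is rounded up to a
    // multiple of kAlignLSE; e.g. a hypothetical lse_extents == 100 yields
    // an extent of ceil_div(100, 32) * 32 == 128.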
1602     IteratorAccumulatorLSE iterator_lse(
1603         lse,
1604         {(int32_t)0, (int32_t)ceil_div(lse_extents, kAlignLSE) * kAlignLSE},
1605         thread_id,
1606         warp_id,
1607         cutlass::MatrixCoord{0, 0} // offset
1608     );
1609 
1610     SmemIteratorD0 smem_iterator_attn(shared_storage.accum_ref(), lane_id);
1611     smem_iterator_attn.add_tile_offset(
1612         tile_coords *
1613         cutlass::MatrixCoord{
1614             SmemIteratorD0::TileIterations::kRow,
1615             SmemIteratorD0::TileIterations::kColumn});
1616     EpilogueWithLSE epilogue;
1617     EpilogueOpApplyLSE minus_lse_exp({});
1618     epilogue(
1619         minus_lse_exp,
1620         smem_iterator_attn,
1621         accum,
1622         // scale - unused
1623         iterator_lse,
1624         // bias
1625         iterator_lse);
1626   }
1627 };
1628 
1629 // Volta specialization
1630 // Only supported for f16 (scalar_t == cutlass::half_t)
1631 template <typename Operator, typename WarpShape_, typename ThreadblockShape_>
1632 struct B2bGemm<
1633     cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator<
1634         cutlass::MatrixShape<32, 32>,
1635         float,
1636         cutlass::layout::RowMajor,
1637         cutlass::gemm::GemmShape<16, 16, 4>,
1638         cutlass::MatrixShape<1, 1>>,
1639     Operator,
1640     cutlass::half_t,
1641     WarpShape_,
1642     ThreadblockShape_> {
1643   using IteratorC =
1644       cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator<
1645           cutlass::MatrixShape<32, 32>,
1646           float,
1647           cutlass::layout::RowMajor,
1648           cutlass::gemm::GemmShape<16, 16, 4>,
1649           cutlass::MatrixShape<1, 1>>;
1650   using scalar_t = cutlass::half_t;
1651   using accum_t = IteratorC::Element;
1652   using WarpShape = WarpShape_;
1653   using ThreadblockShape = ThreadblockShape_;
1654   using FragmentC = IteratorC::Fragment;
1655   using lse_scalar_t = float;
1656 
1657   // Storage in shared-memory for Q.Kt
1658   using SmemAccumulatorLayout =
1659       cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>;
1660   using AccumulatorSharedStorage =
1661       cutlass::gemm::threadblock::AccumulatorSharedStorage<
1662           ThreadblockShape,
1663           scalar_t,
1664           SmemAccumulatorLayout,
1665           cutlass::MatrixShape<0, 0> // Padding
1666           >;
1667   using TensorRef = cutlass::TensorRef<scalar_t, SmemAccumulatorLayout>;
1668   using Policy = typename IteratorC::Policy;
1669   using Element = accum_t;
1670   // These mirror private fields of MmaVoltaTensorOpAccumulatorTileIterator;
1671   // their values are replicated here.
1672   static int const kElementsPerPartial = 4;
1673   using EleShapePerPatial = typename cutlass::platform::conditional<
1674       cutlass::platform::is_same<Element, float>::value,
1675       cutlass::MatrixShape<2, 2>,
1676       cutlass::MatrixShape<1, 4>>::type;
1677   static int const kElementsPerMma = 8;
1678   static int const kAccumulatorPatials = 2;
1679   using QuadShapePerPatialMma = cutlass::MatrixShape<4, 4>;
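  // Note: each lane owns kElementsPerMma == kElementsPerPartial *
  // kAccumulatorPatials (4 * 2 == 8) accumulator elements per mma, split
  // into two "partial" halves of 4 elements each; the store loop below
  // indexes the fragment accordingly.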
1680 
1681   static void CUTLASS_DEVICE accumToSmem(
1682       AccumulatorSharedStorage& shared_storage,
1683       FragmentC const& accum,
1684       int lane_id,
1685       cutlass::MatrixCoord const& tile_coords) {
1686     // ctor - from MmaVoltaTensorOpAccumulatorTileIterator
1687     TensorRef ref_(shared_storage.accum_ref());
1688     int quad = (lane_id >> 2);
1689     int lane_in_quad = (lane_id & 3);
1690     int accum_m, accum_n;
1691 
1692     if (cutlass::platform::is_same<Element, float>::value) {
1693       // (quad[2],quad[0])+lane_in_quad[0]
1694       accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + (lane_in_quad & 1);
1695       // (quad[1])+lane_in_quad[1]
1696       accum_n =
1697           ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials +
1698           (lane_in_quad & 2);
1699     } else {
1700       accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 +
1701           lane_in_quad; // (quad[2],quad[0])
1702       accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials;
1703     }
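    // Worked example for a hypothetical lane_id == 13: quad == 3,
    // lane_in_quad == 1, and for float accumulators
    //   accum_m == ((0 >> 1) + 1) * 8 + (1 & 1) == 9
    //   accum_n == ((3 >> 1) & 1) * 4 * 2 + (1 & 2) == 8
    // i.e. this lane's base offset within the accumulator tile is (9, 8).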
1704     cutlass::MatrixCoord lane_offset(accum_m, accum_n);
1705 
1706     // Tile offset
1707     ref_.add_coord_offset(
1708         tile_coords *
1709         cutlass::MatrixCoord(
1710             {IteratorC::Shape::kRow, IteratorC::Shape::kColumn}));
1711 
1712     using AccessType = cutlass::Array<scalar_t, EleShapePerPatial::kColumn>;
1713 
1714     // store - from MmaVoltaTensorOpAccumulatorTileIterator
1715     CUTLASS_PRAGMA_UNROLL
1716     for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn; ++tile_n) {
1717       CUTLASS_PRAGMA_UNROLL
1718       for (int tile_m = 0; tile_m < Policy::TileIterations::kRow; ++tile_m) {
1719         CUTLASS_PRAGMA_UNROLL
1720         for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) {
1721           CUTLASS_PRAGMA_UNROLL
1722           for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) {
1723             int mma_accum_start =
1724                 (((tile_n * Policy::TileIterations::kRow + tile_m) *
1725                       Policy::MmaIterations::kColumn +
1726                   mma_n) *
1727                      Policy::MmaIterations::kRow +
1728                  mma_m) *
1729                 kElementsPerMma;
1730 
1731             CUTLASS_PRAGMA_UNROLL
1732             for (int p = 0; p < kAccumulatorPatials; ++p) {
1733               CUTLASS_PRAGMA_UNROLL
1734               for (int m = 0; m < EleShapePerPatial::kRow; ++m) {
1735                 int accum_m = tile_m * Policy::InterleavedTile::kRow +
1736                     mma_m * QuadShapePerPatialMma::kRow + m * 2;
1737                 int accum_n = tile_n * Policy::InterleavedTile::kColumn +
1738                     mma_n * QuadShapePerPatialMma::kColumn +
1739                     p * Policy::InterleavedTile::kColumn / 2;
1740                 int r = (accum_m + lane_offset.row());
1741                 AccessType to_store;
1742                 CUTLASS_PRAGMA_UNROLL
1743                 for (int n = 0; n < EleShapePerPatial::kColumn; ++n) {
1744                   int idx = mma_accum_start + p * kElementsPerPartial +
1745                       m * EleShapePerPatial::kColumn + n;
1747                   to_store[n] = scalar_t(accum[idx]);
1748                 }
1749                 int c = (accum_n + lane_offset.column());
1750                 assert(r < 32);
1751                 assert(c < 32);
1752                 *reinterpret_cast<AccessType*>(
1753                     ref_.data() + ref_.offset({r, c})) = to_store;
1754               }
1755             }
1756           }
1757         }
1758       }
1759     }
1760   }
1761 
1762   static void CUTLASS_DEVICE accumApplyLSEToSmem(
1763       AccumulatorSharedStorage& shared_storage,
1764       typename IteratorC::Fragment& accum,
1765       lse_scalar_t const* lse,
1766       int lse_extent,
1767       int thread_id,
1768       int warp_id,
1769       int lane_id,
1770       cutlass::MatrixCoord const& tile_coords) {
1771     // Non-optimized way to apply LSE to registers
1772     // NOTE: accum is attn.T
1773     // TODO: Optimize for each architecture
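    // What the loop below computes, element by element:
    //   accum[m, n] = exp(accum[m, n] - lse[n])
    // where n indexes rows of the untransposed attention matrix (accum holds
    // attn.T); rows beyond lse_extent see lse == +inf, so their exp is 0.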
1774     static constexpr int WarpSize = 32;
1775     using AccumLambdaIterator =
1776         typename DefaultMmaAccumLambdaIterator<IteratorC, accum_t, WarpSize>::
1777             Iterator;
1778     auto lane_offset =
1779         AccumLambdaIterator::get_lane_offset(lane_id, warp_id, tile_coords);
1780 
1781     cutlass::Array<lse_scalar_t, IteratorC::Fragment::kElements> lse_prefetched;
1782     lse_prefetched.clear();
1783     int rowIdx = 0;
1784     int colIdx = 0;
1785     AccumLambdaIterator::iterateRows(
1786         lane_offset,
1787         [&](int accum_m) {
1788           ++rowIdx;
1789           colIdx = 0;
1790         },
1791         [&](int accum_m, int accum_n, int idx) {
1792           if (rowIdx == 1) {
1793             lse_prefetched[colIdx] = accum_n < lse_extent
1794                 ? lse[accum_n]
1795                 : platform::numeric_limits<accum_t>::infinity();
1796           }
1797           accum[idx] = expf(accum[idx] - lse_prefetched[colIdx]);
1798           ++colIdx;
1799         },
1800         [&](int accum_m) {});
1801     accumToSmem(shared_storage, accum, lane_id, tile_coords);
1802   }
1803 };
1804 
1805 // SIMT specialization
1806 // Used for f32 on Sm70-Sm75, and for f16/f32 on architectures below Sm70
1807 
1808 template <
1809     typename Operator,
1810     typename OperatorPolicy,
1811     typename scalar_t,
1812     typename WarpShape_,
1813     typename ThreadblockShape_>
1814 struct B2bGemm<
1815     cutlass::gemm::warp::MmaSimtTileIterator<
1816         cutlass::MatrixShape<32, 32>,
1817         cutlass::gemm::Operand::kC,
1818         float,
1819         cutlass::layout::RowMajor,
1820         OperatorPolicy,
1821         1,
1822         1>,
1823     Operator,
1824     scalar_t,
1825     WarpShape_,
1826     ThreadblockShape_> {
1827   using IteratorC = cutlass::gemm::warp::MmaSimtTileIterator<
1828       cutlass::MatrixShape<32, 32>,
1829       cutlass::gemm::Operand::kC,
1830       float,
1831       cutlass::layout::RowMajor,
1832       OperatorPolicy,
1833       1,
1834       1>;
1835   using accum_t = typename IteratorC::Element;
1836   using WarpShape = WarpShape_;
1837   using ThreadblockShape = ThreadblockShape_;
1838   using FragmentC = typename IteratorC::Fragment;
1839   using lse_scalar_t = float;
1840 
1841   // Storage in shared-memory for Q.Kt
1842   using AccumulatorSharedStorage =
1843       cutlass::gemm::threadblock::AccumulatorSharedStorage<
1844           ThreadblockShape,
1845           scalar_t,
1846           cutlass::layout::ColumnMajor,
1847           cutlass::MatrixShape<0, 0> // Padding
1848           >;
1849 
1850   static void CUTLASS_DEVICE accumToSmem(
1851       AccumulatorSharedStorage& shared_storage,
1852       FragmentC const& accum,
1853       int lane_id,
1854       cutlass::MatrixCoord const& tile_coords) {
1855     using Policy = typename IteratorC::Policy;
1856     using Element = typename IteratorC::Element;
1857     using Iterations = typename IteratorC::Iterations;
1858     using Delta = typename IteratorC::Delta;
1859 
1860     auto ref_ = shared_storage.accum_ref();
1861     // ctor - MmaSimtTileIterator
1862     // compute offset based on thread ID and lane layout
1863     typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();
1864 
1865     MatrixCoord lane_offset = lane_layout.inverse(lane_id) *
1866         MatrixCoord(Policy::LaneMmaShape::kM, Policy::LaneMmaShape::kN);
1867 
1868     ref_.add_coord_offset(lane_offset);
1869 
1870     // Tile offset
1871     ref_.add_coord_offset(
1872         tile_coords *
1873         cutlass::MatrixCoord(
1874             {IteratorC::Shape::kRow, IteratorC::Shape::kColumn}));
1875 
1876     // store - MmaSimtTileIterator
1877     CUTLASS_PRAGMA_UNROLL
1878     for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) {
1879       CUTLASS_PRAGMA_UNROLL
1880       for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) {
1881         CUTLASS_PRAGMA_UNROLL
1882         for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) {
1883           CUTLASS_PRAGMA_UNROLL
1884           for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) {
1885             int r =
1886                 Policy::LaneMmaShape::kM * (mma_m * Policy::WarpShape::kRow) +
1887                 m;
1888             int c = mma_n * Delta::kColumn + n;
1889             int idx = n +
1890                 Policy::LaneMmaShape::kN *
1891                     (mma_n +
1892                      Iterations::kColumn *
1893                          (m + mma_m * Policy::LaneMmaShape::kM));
1894             ref_.at({r, c}) = scalar_t(accum[idx]);
1895           }
1896         }
1897       }
1898     }
1899   }
1900 
1901   static void CUTLASS_DEVICE accumApplyLSEToSmem(
1902       AccumulatorSharedStorage& shared_storage,
1903       typename IteratorC::Fragment& accum,
1904       lse_scalar_t const* lse,
1905       int lse_extent,
1906       int thread_id,
1907       int warp_id,
1908       int lane_id,
1909       cutlass::MatrixCoord const& tile_coords) {
1910     // Non-optimized way to apply LSE to registers
1911     // NOTE: accum is attn.T
1912     // TODO: Optimize for each architecture
1913     static constexpr int WarpSize = 32;
1914     using AccumLambdaIterator =
1915         typename DefaultMmaAccumLambdaIterator<IteratorC, accum_t, WarpSize>::
1916             Iterator;
1917     auto lane_offset =
1918         AccumLambdaIterator::get_lane_offset(lane_id, warp_id, tile_coords);
1919 
1920     cutlass::Array<lse_scalar_t, IteratorC::Fragment::kElements> lse_prefetched;
1921     lse_prefetched.clear();
1922     int rowIdx = 0;
1923     int colIdx = 0;
1924     AccumLambdaIterator::iterateRows(
1925         lane_offset,
1926         [&](int accum_m) {
1927           ++rowIdx;
1928           colIdx = 0;
1929         },
1930         [&](int accum_m, int accum_n, int idx) {
1931           if (rowIdx == 1) {
1932             lse_prefetched[colIdx] = accum_n < lse_extent
1933                 ? lse[accum_n]
1934                 : platform::numeric_limits<accum_t>::infinity();
1935           }
1936           accum[idx] = expf(accum[idx] - lse_prefetched[colIdx]);
1937           ++colIdx;
1938         },
1939         [&](int accum_m) {});
1940     accumToSmem(shared_storage, accum, lane_id, tile_coords);
1941   }
1942 };
1943 
1944 } // namespace threadblock
1945 } // namespace gemm
1946 } // namespace cutlass
1947 
1948 /////////////////////////////////////////////////////////////////////////////////////////////////
1949