/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
/*! \file
    \brief CUTLASS provides helper template functions to figure out the right
    data structures to instantiate to run a GEMM with various parameters (see
    `cutlass/gemm/threadblock/default_mma.h`). However, due to template
    instantiation priority rules, it will only create an MmaMultistage with
    kStages=3 (otherwise it creates an MmaPipelined, which is not compatible
    with FastF32). kStages=3 uses too much shared memory and we want to use
    kStages=2, so we copy-pasted some code from the `default_mma.h` and
    `default_mma_core.h` files and wrapped it in this template to support our
    use case.

    This is really only for the FastF32 case - i.e. using TensorCores with
    fp32 inputs.
*/

#pragma once

#include <cutlass/gemm/threadblock/default_mma.h>
#include <cutlass/gemm/threadblock/default_mma_core_simt.h>
#include <cutlass/gemm/threadblock/default_mma_core_sm70.h>
#include <cutlass/gemm/threadblock/default_mma_core_sm75.h>
#include <cutlass/gemm/threadblock/default_mma_core_sm80.h>

namespace cutlass {
namespace gemm {
namespace threadblock {

/// Primary template: defers to the stock CUTLASS `DefaultMma` with the
/// default options (no shared-memory clearing, accumulators not forced to
/// row-major).
template <
    /// Element type for A matrix operand
    typename ElementA,
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Element type for B matrix operand
    typename ElementB,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Layout type for C and D matrix operand
    typename LayoutC,
    /// Operator class tag
    typename OperatorClass,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Number of stages used in the pipelined mainloop
    int Stages,
    /// Operation performed by GEMM
    typename Operator,
    typename Enable_ = void>
struct FindDefaultMma {
  static constexpr bool AccumulatorsInRowMajor = false;
  static constexpr SharedMemoryClearOption SharedMemoryClear =
      SharedMemoryClearOption::kNone;
  using DefaultMma = cutlass::gemm::threadblock::DefaultMma<
      ElementA,
      LayoutA,
      kAlignmentA,
      ElementB,
      LayoutB,
      kAlignmentB,
      ElementAccumulator,
      LayoutC,
      OperatorClass,
      ArchTag,
      ThreadblockShape,
      WarpShape,
      InstructionShape,
      Stages,
      Operator,
      AccumulatorsInRowMajor,
      SharedMemoryClear>;
};

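// Illustrative only - a minimal sketch of the generic path under assumed
// example parameters (fp16 on sm75, where the stock pipelined mainloop is
// appropriate); the tile shapes and alignments here are made up:
//
//   using Mma = FindDefaultMma<
//       half_t, layout::RowMajor, /*kAlignmentA=*/8,
//       half_t, layout::ColumnMajor, /*kAlignmentB=*/8,
//       float, layout::RowMajor,
//       arch::OpClassTensorOp, arch::Sm75,
//       GemmShape<128, 128, 32>,   // threadblock tile
//       GemmShape<64, 64, 32>,     // warp tile
//       GemmShape<16, 8, 8>,       // sm75 fp16 tensor op instruction shape
//       /*Stages=*/2,
//       arch::OpMultiplyAdd>::DefaultMma;
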
/// Specialization for sm80 / FastF32 / multistage with kStages=2
template <
    typename ElementA_,
    /// Layout type for A matrix operand
    typename LayoutA_,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    typename ElementB_,
    /// Layout type for B matrix operand
    typename LayoutB_,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    typename ElementAccumulator,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    int kStages,
    typename Operator>
struct FindDefaultMma<
    ElementA_,
    LayoutA_,
    kAlignmentA,
    ElementB_,
    LayoutB_,
    kAlignmentB,
    ElementAccumulator,
    layout::RowMajor,
    arch::OpClassTensorOp,
    arch::Sm80,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    kStages,
    Operator,
    typename cutlass::platform::enable_if<(kAlignmentA > 1)>::type> {
  using LayoutC = layout::RowMajor;
  using OperatorClass = arch::OpClassTensorOp;
  using ArchTag = arch::Sm80;

  // Instantiate DefaultMma with 3 stages so that the template priority rules
  // select the multistage core (with 2 stages they would pick MmaPipelined,
  // which does not support FastF32).
  using DefaultMma_ = cutlass::gemm::threadblock::DefaultMma<
      ElementA_,
      LayoutA_,
      kAlignmentA,
      ElementB_,
      LayoutB_,
      kAlignmentB,
      ElementAccumulator,
      LayoutC,
      OperatorClass,
      ArchTag,
      ThreadblockShape,
      WarpShape,
      InstructionShape,
      3,
      Operator>;
  struct DefaultMma : DefaultMma_ {
    using MmaCore_ = typename DefaultMma_::MmaCore;
    // Define the threadblock-scoped multistage matrix multiply, overriding
    // the stage count with the requested kStages.
    using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage<
        typename MmaCore_::Shape,
        typename DefaultMma_::IteratorA,
        typename MmaCore_::SmemIteratorA,
        MmaCore_::kCacheOpA,
        typename DefaultMma_::IteratorB,
        typename MmaCore_::SmemIteratorB,
        MmaCore_::kCacheOpB,
        ElementAccumulator,
        LayoutC,
        typename MmaCore_::MmaPolicy,
        kStages>;
  };
};
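
// Illustrative only - a minimal usage sketch of the FastF32 specialization
// above. The tile shapes and alignments are assumed example values, not
// taken from any kernel in this tree; GemmShape<16, 8, 8> is the sm80 TF32
// tensor op instruction shape:
//
//   using Mma = FindDefaultMma<
//       float, layout::RowMajor, /*kAlignmentA=*/4,
//       float, layout::ColumnMajor, /*kAlignmentB=*/4,
//       float, layout::RowMajor,
//       arch::OpClassTensorOp, arch::Sm80,
//       GemmShape<128, 64, 32>,    // threadblock tile
//       GemmShape<64, 32, 32>,     // warp tile
//       GemmShape<16, 8, 8>,       // instruction tile
//       /*Stages=*/2,              // the point of this file: 2, not 3
//       arch::OpMultiplyAddFastF32>::DefaultMma;
//   using ThreadblockMma = typename Mma::ThreadblockMma;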

} // namespace threadblock
} // namespace gemm
} // namespace cutlass