/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/

#pragma once

#include <cutlass/aligned_buffer.h>
#include <cutlass/arch/memory.h>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/gemm/gemm.h>
#include <cutlass/gemm/threadblock/mma_base.h>
#include <cutlass/matrix_shape.h>
#include <cutlass/numeric_types.h>

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace threadblock {

////////////////////////////////////////////////////////////////////////////////
// SFINAE trick so I can keep the same loop code for Volta and dispatch to the
// correct warp-level MMA. On Volta, all data is stored to shared memory as FP16.
template<typename WarpMma, int kExpansionFactor = 1>
CUTLASS_DEVICE void run_warp_mma(WarpMma&                           warp_mma,
                                 typename WarpMma::FragmentC&       D,
                                 typename WarpMma::FragmentA const& A,
                                 typename WarpMma::FragmentB const& B,
                                 typename WarpMma::FragmentC const& C,
                                 const int                          warp_tileB_k_offset)
{
    warp_mma(D, A, B, C);
}

template<typename WarpMma, int kExpansionFactor = WarpMma::kExpansionFactor>
CUTLASS_DEVICE void run_warp_mma(WarpMma&                                      warp_mma,
                                 typename WarpMma::FragmentC&                  D,
                                 typename WarpMma::TransformedFragmentA const& A,
                                 typename WarpMma::TransformedFragmentB const& B,
                                 typename WarpMma::FragmentC const&            C,
                                 const int                                     warp_tileB_k_offset)
{
    warp_mma(D, A, B, C, warp_tileB_k_offset);
}
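
// Illustrative note (added commentary, not from the original source): overload
// selection here rests on two things. The second overload only enters the
// overload set when WarpMma defines kExpansionFactor (the SFINAE default on its
// template parameter), and the fragment types then pick the match: a
// dequantizing tensor-op warp MMA passing TransformedFragmentA/B binds to the
// second overload and forwards warp_tileB_k_offset, roughly
//
//     run_warp_mma(warp_mma, accum, transformed_frag_A, transformed_frag_B,
//                  accum, warp_tileB_k);  // hypothetical caller-side names
//
// while a Volta-style warp MMA passing plain FragmentA/FragmentB binds to the
// first overload, which ignores the k offset.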
////////////////////////////////////////////////////////////////////////////////

/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template<
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy_,
    /// The type of the scales
    typename ElementScale_,
    /// Number of stages
    int Stages,
    /// Used for partial specialization
    typename Enable = bool>
class DqMmaBase {
public:
    ///< Size of the Gemm problem - concept: gemm::GemmShape<>
    using Shape = Shape_;

    ///< Policy describing tuning details
    using Policy = Policy_;

    ///< Type of the scale to be loaded
    using ElementScale = ElementScale_;

    //
    // Dependent types
    //

    /// Warp-level Mma
    using Operator = typename Policy::Operator;

    /// Shape describing the overall GEMM computed from shared memory
    /// by each warp.
    using WarpGemm = typename Policy::Operator::Shape;

    /// Shape describing the number of warps filling the CTA
    using WarpCount = GemmShape<Shape::kM / WarpGemm::kM, Shape::kN / WarpGemm::kN, Shape::kK / WarpGemm::kK>;
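
    // Worked example (added commentary, not from the original source): for a
    // threadblock Shape of 128x128x64 and a WarpGemm shape of 64x64x64,
    // WarpCount is GemmShape<2, 2, 1>, i.e. 2 * 2 * 1 = 4 warps per CTA.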

    /// Number of warp-level GEMM operations
    static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK);

    /// Number of warp-level MMA k-iterations covered by a single load of a
    /// B-operand fragment
    static constexpr int kNumKIterationsPerWarpBLoad =
        Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK;

    static_assert(!(kWarpGemmIterations % kNumKIterationsPerWarpBLoad),
                  "kWarpGemmIterations must be divisible by kNumKIterationsPerWarpBLoad");

    /// Number of B-fragment loads per warp-level mainloop over the threadblock's K tile
    static constexpr int kWarpGemmIterationsForB = kWarpGemmIterations / kNumKIterationsPerWarpBLoad;
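
    // Illustrative arithmetic (added commentary, not from the original source):
    // with WarpGemm::kK = 64 and an MMA instruction kK of 16,
    // kWarpGemmIterations = 4; if each B load spans 32 rows of K (two
    // instructions), kNumKIterationsPerWarpBLoad = 2 and
    // kWarpGemmIterationsForB = 2, so the mainloop issues four warp MMAs but
    // only two B-fragment loads per warp K tile.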

    /// Number of stages
    static int const kStages = Stages;

    /// Tensor reference to the A operand
    using TensorRefA = TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;

    /// Tensor reference to the B operand
    using TensorRefB = TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;

    //
    // Nested structs
    //

    /// Shared storage object needed by threadblock-scoped GEMM
    class SharedStorage {
    public:
        //
        // Type definitions
        //

        /// Shape of the A matrix operand in shared memory
        using ShapeA =
            MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow, Shape::kK * kStages + Policy::SmemPaddingA::kColumn>;

        /// Shape of the B matrix operand in shared memory
        using ShapeB =
            MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow, Shape::kN + Policy::SmemPaddingB::kColumn>;
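
        // Sizing sketch (added commentary, not from the original source): for
        // a 128x128x64 threadblock with kStages = 3 and no shared-memory
        // padding, ShapeA is 128 x 192 and ShapeB is 192 x 128, so each staged
        // operand buffer holds 24576 elements; operand_scale below adds one
        // scale per column of B, i.e. Shape::kN elements.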

    public:
        //
        // Data members
        //

        /// Buffer for A operand
        AlignedBuffer<typename Operator::ElementA, ShapeA::kCount> operand_A;

        /// Buffer for B operand
        AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;

        /// Buffer to hold scales for threadblock
        AlignedBuffer<ElementScale, Shape::kN> operand_scale;

    public:
        //
        // Methods
        //

        /// Returns a layout object for the A matrix
        CUTLASS_DEVICE
        static typename Operator::LayoutA LayoutA()
        {
            return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
        }

        /// Returns a layout object for the B matrix
        CUTLASS_HOST_DEVICE
        static typename Operator::LayoutB LayoutB()
        {
            return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
        }

        /// Returns a TensorRef to the A operand
        CUTLASS_HOST_DEVICE
        TensorRefA operand_A_ref()
        {
            return TensorRefA{operand_A.data(), LayoutA()};
        }

        /// Returns a TensorRef to the B operand
        CUTLASS_HOST_DEVICE
        TensorRefB operand_B_ref()
        {
            return TensorRefB{operand_B.data(), LayoutB()};
        }
    };

protected:
    //
    // Data members
    //

    /// Iterator to load a warp-scoped tile of A operand from shared memory
    typename Operator::IteratorA warp_tile_iterator_A_;

    /// Iterator to load a warp-scoped tile of B operand from shared memory
    typename Operator::IteratorB warp_tile_iterator_B_;

public:
    /// Construct from tensor references
    CUTLASS_DEVICE
    DqMmaBase(
        ///< Shared storage needed for internal use by threadblock-scoped GEMM
        SharedStorage& shared_storage,
        ///< ID within the threadblock
        int thread_idx,
        ///< ID of warp
        int warp_idx,
        ///< ID of each thread within a warp
        int lane_idx):
        warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx),
        warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx)
    {
    }
};
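
// Usage note (added commentary, an assumption rather than part of the original
// source): concrete mainloops such as the DqMmaPipelined / DqMmaMultistage
// variants in this extension are expected to derive from DqMmaBase, place a
// SharedStorage instance in the kernel's shared-memory union, and step
// warp_tile_iterator_A_ / warp_tile_iterator_B_ while invoking run_warp_mma()
// for each k-group.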

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace threadblock
}  // namespace gemm
}  // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
237