/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/

#pragma once

#include <cutlass/aligned_buffer.h>
#include <cutlass/arch/memory.h>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/gemm/gemm.h>
#include <cutlass/gemm/threadblock/mma_base.h>
#include <cutlass/matrix_shape.h>
#include <cutlass/numeric_types.h>

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace threadblock {

////////////////////////////////////////////////////////////////////////////////
// SFINAE trick so the same mainloop code can be reused for Volta while
// dispatching to the correct warp-level MMA. On Volta, all data is stored to
// shared memory as FP16, so no transformed fragments are needed and the
// k-offset for operand B is ignored.
template<typename WarpMma, int kExpansionFactor = 1>
CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma,
                                 typename WarpMma::FragmentC& D,
                                 typename WarpMma::FragmentA const& A,
                                 typename WarpMma::FragmentB const& B,
                                 typename WarpMma::FragmentC const& C,
                                 const int warp_tileB_k_offset)
{
    warp_mma(D, A, B, C);
}

template<typename WarpMma, int kExpansionFactor = WarpMma::kExpansionFactor>
CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma,
                                 typename WarpMma::FragmentC& D,
                                 typename WarpMma::TransformedFragmentA const& A,
                                 typename WarpMma::TransformedFragmentB const& B,
                                 typename WarpMma::FragmentC const& C,
                                 const int warp_tileB_k_offset)
{
    warp_mma(D, A, B, C, warp_tileB_k_offset);
}
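
// Dispatch note (illustrative, derived from the two overloads above): the
// second overload participates in overload resolution only when WarpMma
// defines kExpansionFactor and the TransformedFragmentA/B types; substitution
// failure removes it otherwise, so a Volta-style WarpMma falls back to the
// first overload, which drops warp_tileB_k_offset. A hedged usage sketch,
// assuming `warp_mma`, `accum`, `frag_A`, and `frag_B` already exist with
// matching types:
//
//   run_warp_mma(warp_mma, accum, frag_A, frag_B, accum, /*warp_tileB_k_offset=*/0);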
////////////////////////////////////////////////////////////////////////////////

/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template<
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy_,
    /// The type of the scales
    typename ElementScale_,
    /// Number of stages
    int Stages,
    /// Used for partial specialization
    typename Enable = bool>
class DqMmaBase {
public:
    ///< Size of the Gemm problem - concept: gemm::GemmShape<>
    using Shape = Shape_;

    ///< Policy describing tuning details
    using Policy = Policy_;

    ///< Type of the scale to be loaded
    using ElementScale = ElementScale_;

    //
    // Dependent types
    //

    /// Warp-level Mma
    using Operator = typename Policy::Operator;

    /// Shape describing the overall GEMM computed from shared memory
    /// by each warp.
    using WarpGemm = typename Policy::Operator::Shape;

    /// Shape describing the number of warps filling the CTA
    using WarpCount = GemmShape<Shape::kM / WarpGemm::kM, Shape::kN / WarpGemm::kN, Shape::kK / WarpGemm::kK>;
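
    // Worked example (illustrative numbers only): a 128x128x64 threadblock
    // tile with a 64x64x64 warp tile gives WarpCount = GemmShape<2, 2, 1>,
    // i.e. four warps cooperating on one threadblock tile.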

    /// Number of warp-level GEMM operations
    static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK);
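
    // E.g. (illustrative): with WarpGemm::kK = 64 and an instruction-level
    // MmaShape::kK of 16, kWarpGemmIterations = 64 / 16 = 4.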

    static constexpr int kNumKIterationsPerWarpBLoad =
        Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK;

    static_assert(!(kWarpGemmIterations % kNumKIterationsPerWarpBLoad),
                  "kWarpGemmIterations must be a multiple of kNumKIterationsPerWarpBLoad");
    static constexpr int kWarpGemmIterationsForB = kWarpGemmIterations / kNumKIterationsPerWarpBLoad;
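
    // E.g. (illustrative): if one warp-level B load spans 32 rows in K while
    // each MMA instruction consumes 16 (kNumKIterationsPerWarpBLoad = 2),
    // every loaded B fragment is reused across two of the K iterations above,
    // so kWarpGemmIterationsForB = kWarpGemmIterations / 2.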

    /// Number of stages
    static int const kStages = Stages;

    /// Tensor reference to the A operand
    using TensorRefA = TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;

    /// Tensor reference to the B operand
    using TensorRefB = TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;

    //
    // Nested structs
    //

    /// Shared storage object needed by threadblock-scoped GEMM
    class SharedStorage {
    public:
        //
        // Type definitions
        //

        /// Shape of the A matrix operand in shared memory
        using ShapeA =
            MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow, Shape::kK * kStages + Policy::SmemPaddingA::kColumn>;

        /// Shape of the B matrix operand in shared memory
        using ShapeB =
            MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow, Shape::kN + Policy::SmemPaddingB::kColumn>;
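
        // E.g. (illustrative): a 128x128x64 threadblock tile with kStages = 2
        // and zero padding gives ShapeA = 128 x 128 and ShapeB = 128 x 128
        // elements; the kStages factor on the K extent is what provides the
        // double buffering.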

    public:
        //
        // Data members
        //

        /// Buffer for A operand
        AlignedBuffer<typename Operator::ElementA, ShapeA::kCount> operand_A;

        /// Buffer for B operand
        AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;

        /// Buffer to hold scales for threadblock
        AlignedBuffer<ElementScale, Shape::kN> operand_scale;
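
        // One scale per column of B (Shape::kN total); derived dequantizing
        // mainloops are expected to apply these to the quantized B operand
        // after it is loaded from shared memory.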

    public:
        //
        // Methods
        //

        /// Returns a layout object for the A matrix
        CUTLASS_DEVICE
        static typename Operator::LayoutA LayoutA()
        {
            return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
        }

        /// Returns a layout object for the B matrix
        CUTLASS_HOST_DEVICE
        static typename Operator::LayoutB LayoutB()
        {
            return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
        }

        /// Returns a TensorRef to the A operand
        CUTLASS_HOST_DEVICE
        TensorRefA operand_A_ref()
        {
            return TensorRefA{operand_A.data(), LayoutA()};
        }

        /// Returns a TensorRef to the B operand
        CUTLASS_HOST_DEVICE
        TensorRefB operand_B_ref()
        {
            return TensorRefB{operand_B.data(), LayoutB()};
        }
    };

protected:
    //
    // Data members
    //

    /// Iterator to load a warp-scoped tile of A operand from shared memory
    typename Operator::IteratorA warp_tile_iterator_A_;

    /// Iterator to load a warp-scoped tile of B operand from shared memory
    typename Operator::IteratorB warp_tile_iterator_B_;

public:
    /// Construct from tensor references
    CUTLASS_DEVICE
    DqMmaBase(
        ///< Shared storage needed for internal use by threadblock-scoped GEMM
        SharedStorage& shared_storage,
        ///< ID within the threadblock
        int thread_idx,
        ///< ID of warp
        int warp_idx,
        ///< ID of each thread within a warp
        int lane_idx):
        warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx),
        warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx)
    {
    }
};
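
// Usage sketch (illustrative; `Mma` is a hypothetical derived mainloop type,
// not defined in this header): the kernel materializes SharedStorage in
// shared memory and constructs the mainloop from it:
//
//   __shared__ typename Mma::SharedStorage shared_storage;
//   Mma mma(shared_storage, thread_idx, warp_idx, lane_idx);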

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace threadblock
}  // namespace gemm
}  // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////