/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief Template for a pipelined GEMM kernel that multiplies a floating-point A operand by a
           quantized integer B operand, dequantizing B with per-column scales. Supports serial
           split-K reduction via the SplitKSerial parameter; does not compute batching.
*/

#pragma once

#include <cutlass/cutlass.h>

#include <cutlass/arch/arch.h>
#include <cutlass/gemm/gemm.h>
#include <cutlass/matrix_coord.h>
#include <cutlass/semaphore.h>

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Mma_,                 ///! Threadblock-scoped matrix multiply-accumulate
         typename Epilogue_,            ///! Epilogue
         typename ThreadblockSwizzle_,  ///! Threadblock swizzling function
         typename KernelArch,  ///! The Architecture this kernel is compiled for. Used since SIMT kernels lose top-level
                               /// arch.
         bool SplitKSerial     ///! If true, code supporting split-K via serial reduction is enabled.
         >
struct GemmFpAIntB {

    using Mma                       = Mma_;
    using Epilogue                  = Epilogue_;
    using EpilogueOutputOp          = typename Epilogue::OutputOp;
    using ThreadblockSwizzle        = ThreadblockSwizzle_;
    static bool const kSplitKSerial = SplitKSerial;

    using ElementA     = typename Mma::IteratorA::Element;
    using LayoutA      = typename Mma::IteratorA::Layout;
    using ElementB     = typename Mma::IteratorB::Element;
    using LayoutB      = typename Mma::IteratorB::Layout;
    using ElementC     = typename Epilogue::OutputTileIterator::Element;
    using LayoutC      = typename Mma::LayoutC;
    using ElementScale = ElementC;

    static ComplexTransform const kTransformA = Mma::kTransformA;
    static ComplexTransform const kTransformB = Mma::kTransformB;

    // Type definitions about the mainloop.
    using Operator         = typename Mma::Operator;
    using OperatorClass    = typename Mma::Operator::OperatorClass;
    using ThreadblockShape = typename Mma::Shape;
    using WarpShape        = typename Mma::Operator::Shape;
    using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
    using ArchTag          = typename Mma::ArchTag;

    static int const kStages     = Mma::kStages;
    static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
    static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
    static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;

    /// Warp count (concept: GemmShape)
    using WarpCount               = typename Mma::WarpCount;
    static int const kThreadCount = 32 * WarpCount::kCount;

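    /// Interleaving factor of the B operand: the ratio of IteratorB's K extent to the threadblock
    /// K tile. It is 1 for plain row-/column-major B and greater than 1 for column-major
    /// interleaved layouts of the quantized weights.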
    static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK;

    /// Argument structure
    struct Arguments {
        GemmUniversalMode mode = GemmUniversalMode::kGemm;

        cutlass::gemm::GemmCoord                         problem_size;
        typename Mma::IteratorA::TensorRef               ref_A;
        typename Mma::IteratorB::TensorRef               ref_B;
        typename Mma::IteratorScale::TensorRef           ref_scale;
        typename Epilogue::OutputTileIterator::TensorRef ref_C;
        typename Epilogue::OutputTileIterator::TensorRef ref_D;

        // Control serial split-k
        int batch_count;

        typename EpilogueOutputOp::Params output_op;

        // For gather+scatter operations
        int const* gather_A_indices;
        int const* gather_B_indices;
        int const* scatter_D_indices;

        // Included so we can use Gemm Universal
        int batch_stride_D = 0;

        //
        // Methods
        //

        CUTLASS_HOST_DEVICE
        Arguments() {}

        CUTLASS_HOST_DEVICE
        Arguments(cutlass::gemm::GemmCoord const&                  problem_size,
                  typename Mma::IteratorA::TensorRef               ref_A,
                  typename Mma::IteratorB::TensorRef               ref_B,
                  typename Mma::IteratorScale::TensorRef           ref_scale,
                  typename Epilogue::OutputTileIterator::TensorRef ref_C,
                  typename Epilogue::OutputTileIterator::TensorRef ref_D,
                  int                                              serial_split_k_factor,
                  typename EpilogueOutputOp::Params                output_op = typename EpilogueOutputOp::Params(),
                  int const*                                       gather_A_indices  = nullptr,
                  int const*                                       gather_B_indices  = nullptr,
                  int const*                                       scatter_D_indices = nullptr):
            problem_size(problem_size),
            ref_A(ref_A),
            ref_B(ref_B),
            ref_scale(ref_scale),
            ref_C(ref_C),
            ref_D(ref_D),
            batch_count(serial_split_k_factor),
            output_op(output_op),
            gather_A_indices(gather_A_indices),
            gather_B_indices(gather_B_indices),
            scatter_D_indices(scatter_D_indices)
        {
        }
    };

    /// Parameters structure
    struct Params
    {
        cutlass::gemm::GemmCoord                         problem_size;
        cutlass::gemm::GemmCoord                         grid_tiled_shape;
        int                                              swizzle_log_tile;
        typename Mma::IteratorA::Params                  params_A;
        typename Mma::IteratorA::TensorRef               ref_A;
        typename Mma::IteratorB::Params                  params_B;
        typename Mma::IteratorB::TensorRef               ref_B;
        typename Mma::IteratorScale::Params              params_scale;
        typename Mma::IteratorScale::TensorRef           ref_scale;
        typename Epilogue::OutputTileIterator::Params    params_C;
        typename Epilogue::OutputTileIterator::TensorRef ref_C;
        typename Epilogue::OutputTileIterator::Params    params_D;
        typename Epilogue::OutputTileIterator::TensorRef ref_D;
        typename EpilogueOutputOp::Params                output_op;
        int*                                             semaphore;
        int                                              gemm_k_size;
        // For gather+scatter operations
        int const* gather_A_indices;
        int const* gather_B_indices;
        int const* scatter_D_indices;

        //
        // Methods
        //

        Params(): swizzle_log_tile(0), semaphore(0), gemm_k_size(0) {}

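        // Note: device_sms and sm_occupancy are accepted below only to mirror a GemmUniversal-style
        // host interface; this Params constructor does not currently use them.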
        Params(Arguments const&                args,
               int                             device_sms,
               int                             sm_occupancy):
            problem_size(args.problem_size),
            swizzle_log_tile(0),
            params_A(args.ref_A.layout()),
            ref_A(args.ref_A),
            params_B(args.ref_B.layout()),
            ref_B(args.ref_B),
            params_scale(args.ref_scale.layout()),
            ref_scale(args.ref_scale),
            params_C(args.ref_C.layout()),
            ref_C(args.ref_C),
            params_D(args.ref_D.layout()),
            ref_D(args.ref_D),
            output_op(args.output_op),
            semaphore(nullptr),
            gather_A_indices(args.gather_A_indices),
            gather_B_indices(args.gather_B_indices),
            scatter_D_indices(args.scatter_D_indices)
        {
            ThreadblockSwizzle swizzle;
            grid_tiled_shape = swizzle.get_tiled_shape(
                args.problem_size,
                {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
                args.batch_count);

            // The swizzle log tile depends on grid_tiled_shape, so it is computed here rather than
            // in the member initializer list, where grid_tiled_shape is not yet set.
            swizzle_log_tile = swizzle.get_log_tile(grid_tiled_shape);

            gemm_k_size = args.problem_size.k();
        }

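        // This kernel requires no device-side workspace; the two methods below exist so the host
        // launcher can treat it uniformly with kernels that do.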
        size_t get_workspace_size() const
        {
            return 0;
        }

        Status init_workspace(void* workspace, cudaStream_t stream = nullptr)
        {
            return Status::kSuccess;
        }

        dim3 get_grid_dims() const
        {
            return ThreadblockSwizzle().get_grid_shape(grid_tiled_shape);
        }
    };

    /// Shared memory storage structure
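    /// The main loop and the epilogue never use their shared memory at the same time, so the two
    /// allocations are overlapped in a union.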
    union SharedStorage {
        typename Mma::SharedStorage      main_loop;
        typename Epilogue::SharedStorage epilogue;
    };

    //
    // Methods
    //

    CUTLASS_HOST_DEVICE
    GemmFpAIntB() {}

    /// Determines whether kernel satisfies alignment
    CUTLASS_HOST_DEVICE
    static Status can_implement(Arguments const& args)
    {

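        // Operands with an interleaved layout must be aligned to the full interleave width (32 or
        // 64 elements); all other layouts only require the iterator's access width.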
        static int const kAlignmentA =
            (platform::is_same<typename Mma::IteratorA::Layout, layout::ColumnMajorInterleaved<32>>::value) ?
                32 :
            (platform::is_same<typename Mma::IteratorA::Layout, layout::ColumnMajorInterleaved<64>>::value) ?
                64 :
                Mma::IteratorA::AccessType::kElements;
        static int const kAlignmentB =
            (platform::is_same<typename Mma::IteratorB::Layout, layout::RowMajorInterleaved<32>>::value) ?
                32 :
            (platform::is_same<typename Mma::IteratorB::Layout, layout::RowMajorInterleaved<64>>::value) ?
                64 :
                Mma::IteratorB::AccessType::kElements;

        static int const kAlignmentScale = Mma::IteratorScale::AccessType::kElements;

        static int const kAlignmentC = (platform::is_same<typename Epilogue::OutputTileIterator::Layout,
                                                          layout::ColumnMajorInterleaved<32>>::value) ?
                                           32 :
                                       (platform::is_same<typename Epilogue::OutputTileIterator::Layout,
                                                          layout::ColumnMajorInterleaved<64>>::value) ?
                                           64 :
                                           Epilogue::OutputTileIterator::kElementsPerAccess;

        if (!TensorRef_aligned(args.ref_A, kAlignmentA)) {
            return Status::kErrorMisalignedOperand;
        }

        if (!TensorRef_aligned(args.ref_B, kAlignmentB)) {
            return Status::kErrorMisalignedOperand;
        }

        if (!TensorRef_aligned(args.ref_scale, kAlignmentScale)) {
            return Status::kErrorMisalignedOperand;
        }

        if (!TensorRef_aligned(args.ref_C, kAlignmentC)) {
            return Status::kErrorMisalignedOperand;
        }

        if (!TensorRef_aligned(args.ref_D, kAlignmentC)) {
            return Status::kErrorMisalignedOperand;
        }

        return Status::kSuccess;
    }

    // The dummy template parameter is not used; it exists so that this code compiles with a
    // standard earlier than C++17. Prior to C++17, fully specialized templates had to exist at
    // namespace scope.
    template<bool B, typename dummy = void>
    struct KernelRunner {
        CUTLASS_DEVICE
        static void run_kernel(Params const& params, SharedStorage& shared_storage)
        {
            CUTLASS_NOT_IMPLEMENTED();
        }
    };

    template<typename dummy>
    struct KernelRunner<true, dummy> {
        CUTLASS_DEVICE
        static void run_kernel(Params const& params, SharedStorage& shared_storage)
        {
            using LayoutB = typename Mma::IteratorB::Layout;
            static_assert((platform::is_same<LayoutB, layout::RowMajor>::value && kInterleave == 1)
                              || (platform::is_same<LayoutB, layout::ColumnMajor>::value && kInterleave >= 1),
                          "B must be row major/col major OR col major interleaved.");

            // Compute threadblock location
            ThreadblockSwizzle threadblock_swizzle;

            cutlass::gemm::GemmCoord threadblock_tile_offset =
                threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

            // Early exit if CTA is out of range
            if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m()
                || params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {

                return;
            }

            // Compute initial location in logical coordinates
            cutlass::MatrixCoord tb_offset_A{
                threadblock_tile_offset.m() * Mma::Shape::kM,
                threadblock_tile_offset.k() * params.gemm_k_size,
            };

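            // B is stored with a kInterleave-way interleaved layout, so its K offset and extent are
            // scaled up by kInterleave while its N offset and extent are scaled down accordingly.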
            cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size * kInterleave,
                                             threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave};

            cutlass::MatrixCoord tb_offset_scale{0, threadblock_tile_offset.n() * Mma::Shape::kN};

            // Problem size is a function of threadblock index in the K dimension
            int problem_size_k = min(params.problem_size.k(), (threadblock_tile_offset.k() + 1) * params.gemm_k_size);

            // Compute threadblock-scoped matrix multiply-add
            int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK;

            // Compute position within threadblock
            int thread_idx = threadIdx.x;

            // Construct iterators to A and B operands
            typename Mma::IteratorA iterator_A(params.params_A,
                                               params.ref_A.data(),
                                               {params.problem_size.m(), problem_size_k},
                                               thread_idx,
                                               tb_offset_A,
                                               params.gather_A_indices);

            typename Mma::IteratorB iterator_B(params.params_B,
                                               params.ref_B.data(),
                                               {problem_size_k * kInterleave, params.problem_size.n() / kInterleave},
                                               thread_idx,
                                               tb_offset_B,
                                               params.gather_B_indices);

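            // The scale tensor is a single row of problem_size.n() per-column dequantization scales
            // that the main loop consumes together with the B fragments.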
            typename Mma::IteratorScale iterator_scale(params.params_scale,
                                                       params.ref_scale.data(),
                                                       {1, params.problem_size.n()},
                                                       thread_idx,
                                                       tb_offset_scale);

            // Broadcast the warp_id computed by lane 0 to ensure dependent code
            // is compiled as warp-uniform.
            int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
            int lane_idx = threadIdx.x % 32;

            //
            // Main loop
            //
            // Construct thread-scoped matrix multiply
            Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);

            typename Mma::FragmentC accumulators;

            accumulators.clear();

            if (!kSplitKSerial || gemm_k_iterations > 0) {
                // Compute threadblock-scoped matrix multiply-add
                mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators);
            }

            //
            // Epilogue
            //

            EpilogueOutputOp output_op(params.output_op);

            //
            // Masked tile iterators constructed from members
            //

            threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

            // assume identity swizzle
            MatrixCoord threadblock_offset(threadblock_tile_offset.m() * Mma::Shape::kM,
                                           threadblock_tile_offset.n() * Mma::Shape::kN);

            int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();

            // Construct the semaphore.
            Semaphore semaphore(params.semaphore + block_idx, thread_idx);

            // If performing a reduction via split-K, fetch the initial synchronization
            if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {

                // Fetch the synchronization lock initially but do not block.
                semaphore.fetch();

                // Indicate which position in a serial reduction the output operator is currently updating
                output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
            }

            // Tile iterator loading from source tensor.
            typename Epilogue::OutputTileIterator iterator_C(params.params_C,
                                                             params.ref_C.data(),
                                                             params.problem_size.mn(),
                                                             thread_idx,
                                                             threadblock_offset,
                                                             params.scatter_D_indices);

            // Tile iterator writing to destination tensor.
            typename Epilogue::OutputTileIterator iterator_D(params.params_D,
                                                             params.ref_D.data(),
                                                             params.problem_size.mn(),
                                                             thread_idx,
                                                             threadblock_offset,
                                                             params.scatter_D_indices);

            Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx);

            // Wait on the semaphore - this latency may have been covered by iterator construction
            if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {

                // For subsequent threadblocks, the source matrix is held in the 'D' tensor.
                if (threadblock_tile_offset.k()) {
                    iterator_C = iterator_D;
                }

                semaphore.wait(threadblock_tile_offset.k());
            }

            // Execute the epilogue operator to update the destination tensor.
            epilogue(output_op, iterator_D, accumulators, iterator_C);

            //
            // Release the semaphore
            //

            if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {

                int lock = 0;
                if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {

                    // The final threadblock resets the semaphore for subsequent grids.
                    lock = 0;
                }
                else {
                    // Otherwise, the semaphore is incremented
                    lock = threadblock_tile_offset.k() + 1;
                }

                semaphore.release(lock);
            }
        }
    };

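    /// Static entry point for launchers that call Operator::invoke(params, shared_storage) rather
    /// than constructing the operator and calling operator().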
    CUTLASS_DEVICE
    static void invoke(Params const& params, SharedStorage& shared_storage)
    {
        GemmFpAIntB op;
        op(params, shared_storage);
    }

    /*
        To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond
        to the ArchTag of the cutlass kernel operator.
    */
    /// Executes one GEMM
    CUTLASS_DEVICE
    void operator()(Params const& params, SharedStorage& shared_storage)
    {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 750)
        static constexpr bool compile_needed = platform::is_same<KernelArch, arch::Sm70>::value;
        KernelRunner<compile_needed>::run_kernel(params, shared_storage);
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
        static constexpr bool compile_needed = platform::is_same<KernelArch, arch::Sm75>::value;
        KernelRunner<compile_needed>::run_kernel(params, shared_storage);
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900)
        static constexpr bool compile_needed = platform::is_same<KernelArch, arch::Sm80>::value;
        KernelRunner<compile_needed>::run_kernel(params, shared_storage);
#else
        CUTLASS_NOT_IMPLEMENTED();
#endif
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace kernel
}  // namespace gemm
}  // namespace cutlass
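
// Illustrative launch sketch (comment only). It assumes `GemmKernel` names a fully specialized
// GemmFpAIntB (the concrete Mma / Epilogue / ThreadblockSwizzle types are assumptions that would
// normally come from a default-traits header) and that the generic cutlass::Kernel<> entry point
// from <cutlass/device_kernel.h> is used to run it:
//
//   typename GemmKernel::Arguments args(problem_size, ref_A, ref_B, ref_scale, ref_C, ref_D,
//                                       /*serial_split_k_factor=*/1);
//   if (GemmKernel::can_implement(args) != cutlass::Status::kSuccess) { /* handle the error */ }
//
//   typename GemmKernel::Params params(args, device_sms, sm_occupancy);
//   dim3 grid  = params.get_grid_dims();
//   dim3 block(GemmKernel::kThreadCount, 1, 1);
//   int  smem_bytes = int(sizeof(typename GemmKernel::SharedStorage));
//   // Above 48 KB of dynamic shared memory, cudaFuncSetAttribute(cutlass::Kernel<GemmKernel>,
//   //     cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes) is typically required first.
//   cutlass::Kernel<GemmKernel><<<grid, block, smem_bytes, stream>>>(params);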