/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Default warp-level GEMM operators selected by data type, size, and layouts of operands.
*/

#pragma once

#include <cutlass/cutlass.h>
#include <cutlass/gemm/warp/default_mma_tensor_op.h>
#include <cutlass/gemm/warp/mma_tensor_op.h>

#include <ATen/native/cuda/cutlass_extensions/arch/mma.h>
#include <ATen/native/cuda/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h>

namespace cutlass {
namespace gemm {
namespace warp {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for the m-by-n-by-k warp-level GEMM selected by the
/// arch::OpMultiplyAddDequantizeInterleavedBToA operator, which dequantizes an
/// interleaved, narrow-precision B operand to ElementA before the tensor-core mma.
template<
    /// Shape of the warp-level matrix product (concept: GemmShape)
    typename WarpShape_,
    /// Shape of one matrix multiply-add instruction (concept: GemmShape)
    typename InstructionShape_,
    /// Data type of A elements
    typename ElementA,
    /// Layout of A matrix (concept: MatrixLayout)
    typename LayoutA,
    /// Data type of B elements
    typename ElementB,
    /// Layout of B matrix (concept: MatrixLayout)
    typename LayoutB,
    /// Element type of C matrix
    typename ElementC,
    /// Layout of C matrix (concept: MatrixLayout)
    typename LayoutC,
    /// Number of partitions along the K dimension
    int PartitionsK,
    /// Store the accumulators in row-major or column-major order. Row-major is
    /// used when the output layout is interleaved.
    bool AccumulatorsInRowMajor>
struct DefaultMmaTensorOp<WarpShape_,
                          InstructionShape_,
                          ElementA,
                          LayoutA,
                          ElementB,
                          LayoutB,
                          ElementC,
                          LayoutC,
                          arch::OpMultiplyAddDequantizeInterleavedBToA,
                          PartitionsK,
                          AccumulatorsInRowMajor> {

private:
    // Instruction shape used for the actual (ElementA-typed, e.g. FP16) computation
    using ComputeInstructionShape = InstructionShape_;

    // K extent for loading B, chosen so that K = 16 for int8 and K = 32 for int4.
    static constexpr int LoadInstructionK = 8 * sizeof_bits<ElementA>::value / sizeof_bits<ElementB>::value;
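    // Illustration (assuming a 16-bit ElementA such as half_t): for an 8-bit B,
    // LoadInstructionK = 8 * 16 / 8 = 16; for a 4-bit B, 8 * 16 / 4 = 32.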

    // Shape used to load the narrow data type from shared memory
    using LoadInstructionShape = GemmShape<InstructionShape_::kM, InstructionShape_::kN, LoadInstructionK>;

public:
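    // Policy describing the tensor-core instruction that performs the actual math.
    // Both operand types are ElementA: the narrow B fragment is dequantized to
    // ElementA (e.g. FP16) before the mma is issued.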
    using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<cutlass::arch::Mma<InstructionShape_,
                                                                             32,
                                                                             ElementA,
                                                                             cutlass::layout::RowMajor,
                                                                             ElementA,
                                                                             cutlass::layout::ColumnMajor,
                                                                             ElementC,
                                                                             cutlass::layout::RowMajor,
                                                                             arch::OpMultiplyAdd>,
                                                          cutlass::MatrixShape<1, 1>>;

    // Define the warp-level tensor op
    using Type = cutlass::gemm::warp::MmaTensorOpComputeBWithF16<WarpShape_,
                                                                 ElementA,
                                                                 LayoutA,
                                                                 ElementB,
                                                                 LayoutB,
                                                                 ElementC,
                                                                 LayoutC,
                                                                 Policy,
                                                                 LoadInstructionShape,
                                                                 PartitionsK,
                                                                 AccumulatorsInRowMajor>;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace warp
}  // namespace gemm
}  // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
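
// Usage sketch (illustrative only; the warp/instruction shapes, element types, and
// layouts below are assumptions chosen for the example, not requirements of this
// header):
//
//   using WarpMma = typename cutlass::gemm::warp::DefaultMmaTensorOp<
//       cutlass::gemm::GemmShape<32, 32, 64>,        // WarpShape
//       cutlass::gemm::GemmShape<16, 8, 16>,         // InstructionShape
//       cutlass::half_t, cutlass::layout::RowMajor,  // A: fp16
//       uint8_t, cutlass::layout::ColumnMajor,       // B: quantized 8-bit
//       float, cutlass::layout::RowMajor,            // C: fp32 accumulators
//       cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA,
//       /*PartitionsK=*/1,
//       /*AccumulatorsInRowMajor=*/false>::Type;
//
// With these choices, LoadInstructionK = 8 * 16 / 8 = 16, so B is loaded with a
// GemmShape<16, 8, 16> instruction and dequantized to fp16 before the mma.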