xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma_base.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights
 * reserved. SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/
35 
#pragma once

#include <cutlass/aligned_buffer.h>
#include <cutlass/arch/memory.h>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/gemm/gemm.h>
#include <cutlass/gemm/threadblock/mma_base.h>
#include <cutlass/matrix_shape.h>
#include <cutlass/numeric_types.h>
47 ////////////////////////////////////////////////////////////////////////////////
48 
namespace cutlass {
namespace gemm {
namespace threadblock {

////////////////////////////////////////////////////////////////////////////////

/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
///
/// Base class holding the shared-memory staging buffers for the A/B operands
/// (sized for `Stages` K-tiles each) and the warp-scoped iterators that read
/// warp tiles back out of them. Derived classes implement the actual mainloop;
/// this class only defines types/shapes and wires the iterators to storage.
template <
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy_,
    /// Number of stages
    int Stages,
    /// Used for partial specialization
    typename Enable = bool>
class CustomMmaBase {
 public:
  ///< Size of the Gemm problem - concept: gemm::GemmShape<>
  using Shape = Shape_;

  ///< Policy describing tuning details
  using Policy = Policy_;

  //
  // Dependent types
  //

  /// Warp-level Mma
  using Operator = typename Policy::Operator;

  /// Shape describing the overall GEMM computed from shared memory
  /// by each warp.
  using WarpGemm = typename Policy::Operator::Shape;

  /// Shape describing the number of warps filling the CTA
  using WarpCount = GemmShape<
      Shape::kM / WarpGemm::kM,
      Shape::kN / WarpGemm::kN,
      Shape::kK / WarpGemm::kK>;

  /// Number of warp-level GEMM operations per warp K-tile
  static int const kWarpGemmIterations =
      (WarpGemm::kK / Operator::Policy::MmaShape::kK);

  /// Number of shared-memory stages (buffers along K)
  static int const kStages = Stages;

  //
  // Nested structs
  //

  /// Shared storage object needed by threadblock-scoped GEMM for one operand:
  /// an aligned backing buffer plus accessors describing its packed layout.
  template <typename Element, typename OperandShape, typename OperandLayout>
  struct OperandSharedStorage {
    /// Backing buffer covering the full (padded, multi-stage) operand tile
    AlignedBuffer<Element, OperandShape::kCount> buffer;
    using TensorRef = TensorRef<Element, OperandLayout>;

    /// Returns the packed layout of the operand tile in shared memory
    CUTLASS_DEVICE
    static OperandLayout Layout() {
      return OperandLayout::packed({OperandShape::kRow, OperandShape::kColumn});
    }

    /// Returns a TensorRef to the operand
    CUTLASS_HOST_DEVICE
    TensorRef ref() {
      return TensorRef{buffer.data(), Layout()};
    }
  };

  /// Shape of the A matrix operand in shared memory
  /// (M rows by K columns per stage, with policy-specified padding)
  using ShapeA = MatrixShape<
      Shape::kM + Policy::SmemPaddingA::kRow,
      Shape::kK * kStages + Policy::SmemPaddingA::kColumn>;

  /// Shape of the B matrix operand in shared memory
  /// (K rows per stage by N columns, with policy-specified padding)
  using ShapeB = MatrixShape<
      Shape::kK * kStages + Policy::SmemPaddingB::kRow,
      Shape::kN + Policy::SmemPaddingB::kColumn>;

  using SharedStorageA = OperandSharedStorage<
      typename Operator::ElementA,
      ShapeA,
      typename Operator::LayoutA>;
  using SharedStorageB = OperandSharedStorage<
      typename Operator::ElementB,
      ShapeB,
      typename Operator::LayoutB>;
  using TensorRefA = typename SharedStorageA::TensorRef;
  using TensorRefB = typename SharedStorageB::TensorRef;

  /// Aggregate shared storage for both operands
  struct SharedStorage {
    /// Buffer for A operand
    SharedStorageA operand_A;

    /// Buffer for B operand
    SharedStorageB operand_B;
  };

 protected:
  //
  // Data members
  //

  /// Iterator to load a warp-scoped tile of A operand from shared memory
  typename Operator::IteratorA warp_tile_iterator_A_;

  /// Iterator to load a warp-scoped tile of B operand from shared memory
  typename Operator::IteratorB warp_tile_iterator_B_;

 public:
  /// Construct from tensor references
  CUTLASS_DEVICE
  CustomMmaBase(
      ///< Shared storage needed for internal use by threadblock-scoped GEMM
      SharedStorageA& shared_storageA,
      SharedStorageB& shared_storageB,
      ///< ID within the threadblock
      int thread_idx,
      ///< ID of warp
      int warp_idx,
      ///< ID of each thread within a warp
      int lane_idx)
      : warp_tile_iterator_A_(shared_storageA.ref(), lane_idx),
        warp_tile_iterator_B_(shared_storageB.ref(), lane_idx) {}
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace gemm
} // namespace cutlass
182 
183 /////////////////////////////////////////////////////////////////////////////////////////////////
184