xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 #pragma once
9 
10 #include <ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma_multistage.h>
11 #include <ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma_pipelined.h>
12 
13 #include <cutlass/gemm/threadblock/mma_multistage.h>
14 #include <cutlass/gemm/threadblock/mma_pipelined.h>
// Type-level mapping from a stock CUTLASS threadblock Mma type to a custom
// counterpart, published as the nested `Mma` alias of each specialization.
// kMaxK appears to be an upper bound on the GEMM K extent, with
// cutlass::platform::numeric_limits<int>::max() meaning "no known bound".
//
// Intentionally declared but never defined generically: only the partial
// specializations in this header (for MmaMultistage and MmaPipelined) supply
// a definition. Instantiating it with any other Mma type therefore fails to
// compile with an incomplete-type error.
template <typename DefaultMma, int kMaxK>
struct MakeCustomMma;
17 
18 template <
19     typename Shape,
20     typename IteratorA,
21     typename SmemIteratorA,
22     cutlass::arch::CacheOperation::Kind CacheOpA,
23     typename IteratorB,
24     typename SmemIteratorB,
25     cutlass::arch::CacheOperation::Kind CacheOpB,
26     typename ElementC,
27     typename LayoutC,
28     typename Policy,
29     int Stages,
30     cutlass::gemm::SharedMemoryClearOption SharedMemoryClear,
31     int kMaxK>
32 struct MakeCustomMma<
33     cutlass::gemm::threadblock::MmaMultistage<
34         Shape,
35         IteratorA,
36         SmemIteratorA,
37         CacheOpA,
38         IteratorB,
39         SmemIteratorB,
40         CacheOpB,
41         ElementC,
42         LayoutC,
43         Policy,
44         Stages,
45         SharedMemoryClear>,
46     kMaxK> {
47   // Reduce the number of stages if we don't need that many
48   static int constexpr kStages =
49       kMaxK == cutlass::platform::numeric_limits<int>::max()
50       ? Stages
51       : cutlass::const_min(
52             Stages,
53             (kMaxK + int(Shape::kK) - 1) / int(Shape::kK));
54   using Mma = cutlass::gemm::threadblock::CustomMmaMultistage<
55       Shape,
56       IteratorA,
57       SmemIteratorA,
58       CacheOpA,
59       IteratorB,
60       SmemIteratorB,
61       CacheOpB,
62       ElementC,
63       LayoutC,
64       Policy,
65       kStages,
66       SharedMemoryClear,
67       kMaxK>;
68 };
69 
70 template <
71     typename Shape,
72     typename IteratorA,
73     typename SmemIteratorA,
74     typename IteratorB,
75     typename SmemIteratorB,
76     typename ElementC,
77     typename LayoutC,
78     typename Policy,
79     int kMaxK>
80 struct MakeCustomMma<
81     cutlass::gemm::threadblock::MmaPipelined<
82         Shape,
83         IteratorA,
84         SmemIteratorA,
85         IteratorB,
86         SmemIteratorB,
87         ElementC,
88         LayoutC,
89         Policy>,
90     kMaxK> {
91   using Mma = cutlass::gemm::threadblock::CustomMmaPipelined<
92       Shape,
93       IteratorA,
94       SmemIteratorA,
95       IteratorB,
96       SmemIteratorB,
97       ElementC,
98       LayoutC,
99       Policy>;
100 };
101