/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
// This file is auto-generated. See "generate_kernels.py"
#pragma once
#include <ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h>
using namespace PyTorchMemEffAttention;
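// Each AttentionBackwardKernel instantiation below is parameterized as
// <ArchTag, scalar_t, kIsAligned, kApplyDropout, kPreload, kBlockSizeI,
//  kBlockSizeJ, kMaxK[, kKeysQueriesAlignedToBlockSize]> (see
// kernel_backward.h). The kernel names encode the same choices, e.g.
// "aligned_128x64_k128_seqaligned" = aligned loads, 128x64 tile, kMaxK=128,
// sequence lengths aligned to the block size.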
// ======== f16 / sm70 ========
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, false, false, 64, 64, 32, true>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, false, false, 64, 64, 32, true>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_64x64_k32_seqaligned_sm70(typename AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, false, false, 64, 64, 32, true>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, false, false, 64, 64, 32>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, false, false, 64, 64, 32>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_64x64_k32_sm70(typename AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, false, false, 64, 64, 32>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, false, false, 64, 64, 64, true>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, false, false, 64, 64, 64, true>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_64x64_k64_seqaligned_sm70(typename AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, false, false, 64, 64, 64, true>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, false, false, 64, 64, 64>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, false, false, 64, 64, 64>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_64x64_k64_sm70(typename AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, false, false, 64, 64, 64>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, false, false, 128, 64, 128, true>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, false, false, 128, 64, 128, true>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_128x64_k128_seqaligned_sm70(typename AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, false, false, 128, 64, 128, true>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, false, false, 128, 64, 128>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, false, false, 128, 64, 128>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_128x64_k128_sm70(typename AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, false, false, 128, 64, 128>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, false, false, 64, 64, 128, true>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, false, false, 64, 64, 128, true>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_64x64_k128_seqaligned_sm70(typename AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, false, false, 64, 64, 128, true>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, false, false, 64, 64, 128>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, false, false, 64, 64, 128>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_64x64_k128_sm70(typename AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, false, false, 64, 64, 128>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, false, false, 128, 64, 65536>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, false, false, 128, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_128x64_k65536_sm70(typename AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, false, false, 128, 64, 65536>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, false, false, 64, 64, 65536>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, false, false, 64, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_64x64_k65536_sm70(typename AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, false, false, 64, 64, 65536>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, true, false, 64, 64, 32>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, true, false, 64, 64, 32>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_64x64_k32_dropout_sm70(typename AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, true, false, 64, 64, 32>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, true, false, 64, 64, 64>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, true, false, 64, 64, 64>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_64x64_k64_dropout_sm70(typename AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, true, false, 64, 64, 64>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, true, false, 128, 64, 128>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, true, false, 128, 64, 128>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_128x64_k128_dropout_sm70(typename AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, true, false, 128, 64, 128>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, true, false, 64, 64, 128>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, true, false, 64, 64, 128>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_64x64_k128_dropout_sm70(typename AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, true, false, 64, 64, 128>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, true, false, 128, 64, 65536>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, true, false, 128, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_128x64_k65536_dropout_sm70(typename AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, true, false, 128, 64, 65536>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, true, false, 64, 64, 65536>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, true, false, 64, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_64x64_k65536_dropout_sm70(typename AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, true, false, 64, 64, 65536>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, false, false, false, 64, 64, 32>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, false, false, false, 64, 64, 32>::kMinBlocksPerSm)
fmha_cutlassB_f16_notaligned_64x64_k32_sm70(typename AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, false, false, false, 64, 64, 32>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, false, false, false, 64, 64, 64>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, false, false, false, 64, 64, 64>::kMinBlocksPerSm)
fmha_cutlassB_f16_notaligned_64x64_k64_sm70(typename AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, false, false, false, 64, 64, 64>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, false, false, false, 128, 64, 128>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, false, false, false, 128, 64, 128>::kMinBlocksPerSm)
fmha_cutlassB_f16_notaligned_128x64_k128_sm70(typename AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, false, false, false, 128, 64, 128>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, false, false, false, 64, 64, 128>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, false, false, false, 64, 64, 128>::kMinBlocksPerSm)
fmha_cutlassB_f16_notaligned_64x64_k128_sm70(typename AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, false, false, false, 64, 64, 128>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, false, false, false, 128, 64, 65536>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, false, false, false, 128, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_f16_notaligned_128x64_k65536_sm70(typename AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, false, false, false, 128, 64, 65536>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, false, false, false, 64, 64, 65536>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, false, false, false, 64, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_f16_notaligned_64x64_k65536_sm70(typename AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, false, false, false, 64, 64, 65536>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, false, true, false, 64, 64, 32>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, false, true, false, 64, 64, 32>::kMinBlocksPerSm)
fmha_cutlassB_f16_notaligned_64x64_k32_dropout_sm70(typename AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, false, true, false, 64, 64, 32>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, false, true, false, 64, 64, 64>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, false, true, false, 64, 64, 64>::kMinBlocksPerSm)
fmha_cutlassB_f16_notaligned_64x64_k64_dropout_sm70(typename AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, false, true, false, 64, 64, 64>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, false, true, false, 128, 64, 128>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, false, true, false, 128, 64, 128>::kMinBlocksPerSm)
fmha_cutlassB_f16_notaligned_128x64_k128_dropout_sm70(typename AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, false, true, false, 128, 64, 128>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, false, true, false, 64, 64, 128>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, false, true, false, 64, 64, 128>::kMinBlocksPerSm)
fmha_cutlassB_f16_notaligned_64x64_k128_dropout_sm70(typename AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, false, true, false, 64, 64, 128>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, false, true, false, 128, 64, 65536>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, false, true, false, 128, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_f16_notaligned_128x64_k65536_dropout_sm70(typename AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, false, true, false, 128, 64, 65536>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, false, true, false, 64, 64, 65536>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, false, true, false, 64, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_f16_notaligned_64x64_k65536_dropout_sm70(typename AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, false, true, false, 64, 64, 65536>::Params p);

template <typename T> void dispatch_cutlassB_f16_sm70(T cb, int cc) {
    cb(AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, false, false, 64, 64, 32, true>(), fmha_cutlassB_f16_aligned_64x64_k32_seqaligned_sm70);
    cb(AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, false, false, 64, 64, 32>(), fmha_cutlassB_f16_aligned_64x64_k32_sm70);
    cb(AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, false, false, 64, 64, 64, true>(), fmha_cutlassB_f16_aligned_64x64_k64_seqaligned_sm70);
    cb(AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, false, false, 64, 64, 64>(), fmha_cutlassB_f16_aligned_64x64_k64_sm70);
    cb(AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, false, false, 128, 64, 128, true>(), fmha_cutlassB_f16_aligned_128x64_k128_seqaligned_sm70);
    cb(AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, false, false, 128, 64, 128>(), fmha_cutlassB_f16_aligned_128x64_k128_sm70);
    cb(AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, false, false, 64, 64, 128, true>(), fmha_cutlassB_f16_aligned_64x64_k128_seqaligned_sm70);
    cb(AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, false, false, 64, 64, 128>(), fmha_cutlassB_f16_aligned_64x64_k128_sm70);
    cb(AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, false, false, 128, 64, 65536>(), fmha_cutlassB_f16_aligned_128x64_k65536_sm70);
    cb(AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, false, false, 64, 64, 65536>(), fmha_cutlassB_f16_aligned_64x64_k65536_sm70);
    cb(AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, true, false, 64, 64, 32>(), fmha_cutlassB_f16_aligned_64x64_k32_dropout_sm70);
    cb(AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, true, false, 64, 64, 64>(), fmha_cutlassB_f16_aligned_64x64_k64_dropout_sm70);
    cb(AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, true, false, 128, 64, 128>(), fmha_cutlassB_f16_aligned_128x64_k128_dropout_sm70);
    cb(AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, true, false, 64, 64, 128>(), fmha_cutlassB_f16_aligned_64x64_k128_dropout_sm70);
    cb(AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, true, false, 128, 64, 65536>(), fmha_cutlassB_f16_aligned_128x64_k65536_dropout_sm70);
    cb(AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, true, true, false, 64, 64, 65536>(), fmha_cutlassB_f16_aligned_64x64_k65536_dropout_sm70);
    cb(AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, false, false, false, 64, 64, 32>(), fmha_cutlassB_f16_notaligned_64x64_k32_sm70);
    cb(AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, false, false, false, 64, 64, 64>(), fmha_cutlassB_f16_notaligned_64x64_k64_sm70);
    cb(AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, false, false, false, 128, 64, 128>(), fmha_cutlassB_f16_notaligned_128x64_k128_sm70);
    cb(AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, false, false, false, 64, 64, 128>(), fmha_cutlassB_f16_notaligned_64x64_k128_sm70);
    cb(AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, false, false, false, 128, 64, 65536>(), fmha_cutlassB_f16_notaligned_128x64_k65536_sm70);
    cb(AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, false, false, false, 64, 64, 65536>(), fmha_cutlassB_f16_notaligned_64x64_k65536_sm70);
    cb(AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, false, true, false, 64, 64, 32>(), fmha_cutlassB_f16_notaligned_64x64_k32_dropout_sm70);
    cb(AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, false, true, false, 64, 64, 64>(), fmha_cutlassB_f16_notaligned_64x64_k64_dropout_sm70);
    cb(AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, false, true, false, 128, 64, 128>(), fmha_cutlassB_f16_notaligned_128x64_k128_dropout_sm70);
    cb(AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, false, true, false, 64, 64, 128>(), fmha_cutlassB_f16_notaligned_64x64_k128_dropout_sm70);
    cb(AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, false, true, false, 128, 64, 65536>(), fmha_cutlassB_f16_notaligned_128x64_k65536_dropout_sm70);
    cb(AttentionBackwardKernel<cutlass::arch::Sm70, cutlass::half_t, false, true, false, 64, 64, 65536>(), fmha_cutlassB_f16_notaligned_64x64_k65536_dropout_sm70);
}
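
// Usage sketch (illustrative only; `chosen` and `head_dim` are hypothetical
// caller-side names, not part of this header): each dispatch_* helper passes
// every kernel instantiation for one dtype/arch pair to `cb` as a
// (default-constructed kernel-traits object, __global__ entry point) pair,
// so a callback can pick a kernel from its static members (e.g. kMaxK,
// kNumThreads):
//
//   void* chosen = nullptr;
//   dispatch_cutlassB_f16_sm70([&](auto kernel, auto kernel_fn) {
//     using Kernel = decltype(kernel);
//     if (chosen == nullptr && Kernel::kMaxK >= head_dim) {
//       chosen = reinterpret_cast<void*>(kernel_fn);
//     }
//   }, /*cc=*/70);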

// ======== bf16 / sm80 ========
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, true, 64, 64, 32, true>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, true, 64, 64, 32, true>::kMinBlocksPerSm)
fmha_cutlassB_bf16_aligned_64x64_k32_seqaligned_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, true, 64, 64, 32, true>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, true, 64, 64, 32>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, true, 64, 64, 32>::kMinBlocksPerSm)
fmha_cutlassB_bf16_aligned_64x64_k32_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, true, 64, 64, 32>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, true, 64, 64, 64, true>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, true, 64, 64, 64, true>::kMinBlocksPerSm)
fmha_cutlassB_bf16_aligned_64x64_k64_seqaligned_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, true, 64, 64, 64, true>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, true, 64, 64, 64>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, true, 64, 64, 64>::kMinBlocksPerSm)
fmha_cutlassB_bf16_aligned_64x64_k64_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, true, 64, 64, 64>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, true, 128, 64, 96>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, true, 128, 64, 96>::kMinBlocksPerSm)
fmha_cutlassB_bf16_aligned_128x64_k96_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, true, 128, 64, 96>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, true, 128, 128, 128, true>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, true, 128, 128, 128, true>::kMinBlocksPerSm)
fmha_cutlassB_bf16_aligned_128x128_k128_seqaligned_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, true, 128, 128, 128, true>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, true, 128, 128, 128>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, true, 128, 128, 128>::kMinBlocksPerSm)
fmha_cutlassB_bf16_aligned_128x128_k128_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, true, 128, 128, 128>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, false, 64, 64, 128, true>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, false, 64, 64, 128, true>::kMinBlocksPerSm)
fmha_cutlassB_bf16_aligned_64x64_k128_seqaligned_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, false, 64, 64, 128, true>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, false, 64, 64, 128>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, false, 64, 64, 128>::kMinBlocksPerSm)
fmha_cutlassB_bf16_aligned_64x64_k128_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, false, 64, 64, 128>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, false, 128, 64, 65536>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, false, 128, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_bf16_aligned_128x64_k65536_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, false, 128, 64, 65536>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, false, 64, 64, 65536>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, false, 64, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_bf16_aligned_64x64_k65536_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, false, 64, 64, 65536>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, true, true, 64, 64, 32>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, true, true, 64, 64, 32>::kMinBlocksPerSm)
fmha_cutlassB_bf16_aligned_64x64_k32_dropout_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, true, true, 64, 64, 32>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, true, true, 64, 64, 64>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, true, true, 64, 64, 64>::kMinBlocksPerSm)
fmha_cutlassB_bf16_aligned_64x64_k64_dropout_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, true, true, 64, 64, 64>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, true, true, 128, 128, 128>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, true, true, 128, 128, 128>::kMinBlocksPerSm)
fmha_cutlassB_bf16_aligned_128x128_k128_dropout_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, true, true, 128, 128, 128>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, true, false, 64, 64, 128>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, true, false, 64, 64, 128>::kMinBlocksPerSm)
fmha_cutlassB_bf16_aligned_64x64_k128_dropout_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, true, false, 64, 64, 128>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, true, false, 128, 64, 65536>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, true, false, 128, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_bf16_aligned_128x64_k65536_dropout_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, true, false, 128, 64, 65536>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, true, false, 64, 64, 65536>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, true, false, 64, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_bf16_aligned_64x64_k65536_dropout_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, true, false, 64, 64, 65536>::Params p);

template <typename T> void dispatch_cutlassB_bf16_sm80(T cb, int cc) {
    cb(AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, true, 64, 64, 32, true>(), fmha_cutlassB_bf16_aligned_64x64_k32_seqaligned_sm80);
    cb(AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, true, 64, 64, 32>(), fmha_cutlassB_bf16_aligned_64x64_k32_sm80);
    cb(AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, true, 64, 64, 64, true>(), fmha_cutlassB_bf16_aligned_64x64_k64_seqaligned_sm80);
    cb(AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, true, 64, 64, 64>(), fmha_cutlassB_bf16_aligned_64x64_k64_sm80);
    if (cc == 86 || cc == 89) cb(AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, true, 128, 64, 96>(), fmha_cutlassB_bf16_aligned_128x64_k96_sm80);
    cb(AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, true, 128, 128, 128, true>(), fmha_cutlassB_bf16_aligned_128x128_k128_seqaligned_sm80);
    cb(AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, true, 128, 128, 128>(), fmha_cutlassB_bf16_aligned_128x128_k128_sm80);
    cb(AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, false, 64, 64, 128, true>(), fmha_cutlassB_bf16_aligned_64x64_k128_seqaligned_sm80);
    cb(AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, false, 64, 64, 128>(), fmha_cutlassB_bf16_aligned_64x64_k128_sm80);
    cb(AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, false, 128, 64, 65536>(), fmha_cutlassB_bf16_aligned_128x64_k65536_sm80);
    cb(AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, false, 64, 64, 65536>(), fmha_cutlassB_bf16_aligned_64x64_k65536_sm80);
    cb(AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, true, true, 64, 64, 32>(), fmha_cutlassB_bf16_aligned_64x64_k32_dropout_sm80);
    cb(AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, true, true, 64, 64, 64>(), fmha_cutlassB_bf16_aligned_64x64_k64_dropout_sm80);
    cb(AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, true, true, 128, 128, 128>(), fmha_cutlassB_bf16_aligned_128x128_k128_dropout_sm80);
    cb(AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, true, false, 64, 64, 128>(), fmha_cutlassB_bf16_aligned_64x64_k128_dropout_sm80);
    cb(AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, true, false, 128, 64, 65536>(), fmha_cutlassB_bf16_aligned_128x64_k65536_dropout_sm80);
    cb(AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, true, false, 64, 64, 65536>(), fmha_cutlassB_bf16_aligned_64x64_k65536_dropout_sm80);
}

// ======== f16 / sm80 ========
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, true, 64, 64, 32, true>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, true, 64, 64, 32, true>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_64x64_k32_seqaligned_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, true, 64, 64, 32, true>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, true, 64, 64, 32>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, true, 64, 64, 32>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_64x64_k32_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, true, 64, 64, 32>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, true, 64, 64, 64, true>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, true, 64, 64, 64, true>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_64x64_k64_seqaligned_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, true, 64, 64, 64, true>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, true, 64, 64, 64>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, true, 64, 64, 64>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_64x64_k64_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, true, 64, 64, 64>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, true, 128, 64, 96>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, true, 128, 64, 96>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_128x64_k96_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, true, 128, 64, 96>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, true, 128, 128, 128, true>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, true, 128, 128, 128, true>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_128x128_k128_seqaligned_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, true, 128, 128, 128, true>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, true, 128, 128, 128>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, true, 128, 128, 128>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_128x128_k128_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, true, 128, 128, 128>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, false, 64, 64, 128, true>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, false, 64, 64, 128, true>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_64x64_k128_seqaligned_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, false, 64, 64, 128, true>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, false, 64, 64, 128>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, false, 64, 64, 128>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_64x64_k128_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, false, 64, 64, 128>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, false, 128, 64, 65536>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, false, 128, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_128x64_k65536_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, false, 128, 64, 65536>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, false, 64, 64, 65536>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, false, 64, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_64x64_k65536_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, false, 64, 64, 65536>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, true, true, 64, 64, 32>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, true, true, 64, 64, 32>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_64x64_k32_dropout_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, true, true, 64, 64, 32>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, true, true, 64, 64, 64>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, true, true, 64, 64, 64>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_64x64_k64_dropout_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, true, true, 64, 64, 64>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, true, true, 128, 128, 128>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, true, true, 128, 128, 128>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_128x128_k128_dropout_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, true, true, 128, 128, 128>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, true, false, 64, 64, 128>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, true, false, 64, 64, 128>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_64x64_k128_dropout_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, true, false, 64, 64, 128>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, true, false, 128, 64, 65536>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, true, false, 128, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_128x64_k65536_dropout_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, true, false, 128, 64, 65536>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, true, false, 64, 64, 65536>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, true, false, 64, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_64x64_k65536_dropout_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, true, false, 64, 64, 65536>::Params p);

template <typename T> void dispatch_cutlassB_f16_sm80(T cb, int cc) {
    cb(AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, true, 64, 64, 32, true>(), fmha_cutlassB_f16_aligned_64x64_k32_seqaligned_sm80);
    cb(AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, true, 64, 64, 32>(), fmha_cutlassB_f16_aligned_64x64_k32_sm80);
    cb(AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, true, 64, 64, 64, true>(), fmha_cutlassB_f16_aligned_64x64_k64_seqaligned_sm80);
    cb(AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, true, 64, 64, 64>(), fmha_cutlassB_f16_aligned_64x64_k64_sm80);
    if (cc == 86 || cc == 89) cb(AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, true, 128, 64, 96>(), fmha_cutlassB_f16_aligned_128x64_k96_sm80);
    cb(AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, true, 128, 128, 128, true>(), fmha_cutlassB_f16_aligned_128x128_k128_seqaligned_sm80);
    cb(AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, true, 128, 128, 128>(), fmha_cutlassB_f16_aligned_128x128_k128_sm80);
    cb(AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, false, 64, 64, 128, true>(), fmha_cutlassB_f16_aligned_64x64_k128_seqaligned_sm80);
    cb(AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, false, 64, 64, 128>(), fmha_cutlassB_f16_aligned_64x64_k128_sm80);
    cb(AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, false, 128, 64, 65536>(), fmha_cutlassB_f16_aligned_128x64_k65536_sm80);
    cb(AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, false, false, 64, 64, 65536>(), fmha_cutlassB_f16_aligned_64x64_k65536_sm80);
    cb(AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, true, true, 64, 64, 32>(), fmha_cutlassB_f16_aligned_64x64_k32_dropout_sm80);
    cb(AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, true, true, 64, 64, 64>(), fmha_cutlassB_f16_aligned_64x64_k64_dropout_sm80);
    cb(AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, true, true, 128, 128, 128>(), fmha_cutlassB_f16_aligned_128x128_k128_dropout_sm80);
    cb(AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, true, false, 64, 64, 128>(), fmha_cutlassB_f16_aligned_64x64_k128_dropout_sm80);
    cb(AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, true, false, 128, 64, 65536>(), fmha_cutlassB_f16_aligned_128x64_k65536_dropout_sm80);
    cb(AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::half_t, true, true, false, 64, 64, 65536>(), fmha_cutlassB_f16_aligned_64x64_k65536_dropout_sm80);
}

// ======== f16 / sm50 ========
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, true, false, false, 64, 64, 32>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, true, false, false, 64, 64, 32>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_64x64_k32_sm50(typename AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, true, false, false, 64, 64, 32>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, true, false, false, 64, 64, 64>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, true, false, false, 64, 64, 64>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_64x64_k64_sm50(typename AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, true, false, false, 64, 64, 64>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, true, false, false, 64, 64, 128>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, true, false, false, 64, 64, 128>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_64x64_k128_sm50(typename AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, true, false, false, 64, 64, 128>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, true, false, false, 64, 64, 65536>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, true, false, false, 64, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_64x64_k65536_sm50(typename AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, true, false, false, 64, 64, 65536>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, true, true, false, 64, 64, 32>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, true, true, false, 64, 64, 32>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_64x64_k32_dropout_sm50(typename AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, true, true, false, 64, 64, 32>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, true, true, false, 64, 64, 64>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, true, true, false, 64, 64, 64>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_64x64_k64_dropout_sm50(typename AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, true, true, false, 64, 64, 64>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, true, true, false, 64, 64, 128>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, true, true, false, 64, 64, 128>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_64x64_k128_dropout_sm50(typename AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, true, true, false, 64, 64, 128>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, true, true, false, 64, 64, 65536>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, true, true, false, 64, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_64x64_k65536_dropout_sm50(typename AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, true, true, false, 64, 64, 65536>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, false, false, false, 64, 64, 32>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, false, false, false, 64, 64, 32>::kMinBlocksPerSm)
fmha_cutlassB_f16_notaligned_64x64_k32_sm50(typename AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, false, false, false, 64, 64, 32>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, false, false, false, 64, 64, 64>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, false, false, false, 64, 64, 64>::kMinBlocksPerSm)
fmha_cutlassB_f16_notaligned_64x64_k64_sm50(typename AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, false, false, false, 64, 64, 64>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, false, false, false, 64, 64, 128>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, false, false, false, 64, 64, 128>::kMinBlocksPerSm)
fmha_cutlassB_f16_notaligned_64x64_k128_sm50(typename AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, false, false, false, 64, 64, 128>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, false, false, false, 64, 64, 65536>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, false, false, false, 64, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_f16_notaligned_64x64_k65536_sm50(typename AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, false, false, false, 64, 64, 65536>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, false, true, false, 64, 64, 32>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, false, true, false, 64, 64, 32>::kMinBlocksPerSm)
fmha_cutlassB_f16_notaligned_64x64_k32_dropout_sm50(typename AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, false, true, false, 64, 64, 32>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, false, true, false, 64, 64, 64>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, false, true, false, 64, 64, 64>::kMinBlocksPerSm)
fmha_cutlassB_f16_notaligned_64x64_k64_dropout_sm50(typename AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, false, true, false, 64, 64, 64>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, false, true, false, 64, 64, 128>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, false, true, false, 64, 64, 128>::kMinBlocksPerSm)
fmha_cutlassB_f16_notaligned_64x64_k128_dropout_sm50(typename AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, false, true, false, 64, 64, 128>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, false, true, false, 64, 64, 65536>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, false, true, false, 64, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_f16_notaligned_64x64_k65536_dropout_sm50(typename AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, false, true, false, 64, 64, 65536>::Params p);

template <typename T> void dispatch_cutlassB_f16_sm50(T cb, int cc) {
    cb(AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, true, false, false, 64, 64, 32>(), fmha_cutlassB_f16_aligned_64x64_k32_sm50);
    cb(AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, true, false, false, 64, 64, 64>(), fmha_cutlassB_f16_aligned_64x64_k64_sm50);
    cb(AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, true, false, false, 64, 64, 128>(), fmha_cutlassB_f16_aligned_64x64_k128_sm50);
    cb(AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, true, false, false, 64, 64, 65536>(), fmha_cutlassB_f16_aligned_64x64_k65536_sm50);
    cb(AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, true, true, false, 64, 64, 32>(), fmha_cutlassB_f16_aligned_64x64_k32_dropout_sm50);
    cb(AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, true, true, false, 64, 64, 64>(), fmha_cutlassB_f16_aligned_64x64_k64_dropout_sm50);
    cb(AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, true, true, false, 64, 64, 128>(), fmha_cutlassB_f16_aligned_64x64_k128_dropout_sm50);
    cb(AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, true, true, false, 64, 64, 65536>(), fmha_cutlassB_f16_aligned_64x64_k65536_dropout_sm50);
    cb(AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, false, false, false, 64, 64, 32>(), fmha_cutlassB_f16_notaligned_64x64_k32_sm50);
    cb(AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, false, false, false, 64, 64, 64>(), fmha_cutlassB_f16_notaligned_64x64_k64_sm50);
    cb(AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, false, false, false, 64, 64, 128>(), fmha_cutlassB_f16_notaligned_64x64_k128_sm50);
    cb(AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, false, false, false, 64, 64, 65536>(), fmha_cutlassB_f16_notaligned_64x64_k65536_sm50);
    cb(AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, false, true, false, 64, 64, 32>(), fmha_cutlassB_f16_notaligned_64x64_k32_dropout_sm50);
    cb(AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, false, true, false, 64, 64, 64>(), fmha_cutlassB_f16_notaligned_64x64_k64_dropout_sm50);
    cb(AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, false, true, false, 64, 64, 128>(), fmha_cutlassB_f16_notaligned_64x64_k128_dropout_sm50);
    cb(AttentionBackwardKernel<cutlass::arch::Sm50, cutlass::half_t, false, true, false, 64, 64, 65536>(), fmha_cutlassB_f16_notaligned_64x64_k65536_dropout_sm50);
}

// ======== f32 / sm50 ========
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm50, float, true, false, false, 64, 64, 32>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm50, float, true, false, false, 64, 64, 32>::kMinBlocksPerSm)
fmha_cutlassB_f32_aligned_64x64_k32_sm50(typename AttentionBackwardKernel<cutlass::arch::Sm50, float, true, false, false, 64, 64, 32>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm50, float, true, false, false, 64, 64, 64>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm50, float, true, false, false, 64, 64, 64>::kMinBlocksPerSm)
fmha_cutlassB_f32_aligned_64x64_k64_sm50(typename AttentionBackwardKernel<cutlass::arch::Sm50, float, true, false, false, 64, 64, 64>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm50, float, true, false, false, 64, 64, 128>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm50, float, true, false, false, 64, 64, 128>::kMinBlocksPerSm)
fmha_cutlassB_f32_aligned_64x64_k128_sm50(typename AttentionBackwardKernel<cutlass::arch::Sm50, float, true, false, false, 64, 64, 128>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm50, float, true, false, false, 64, 64, 65536>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm50, float, true, false, false, 64, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_f32_aligned_64x64_k65536_sm50(typename AttentionBackwardKernel<cutlass::arch::Sm50, float, true, false, false, 64, 64, 65536>::Params p);
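// On CUDA 12.4 exactly, the two f32 dropout kernels below are instantiated
// with smaller 32x32 tiles instead of the usual 64x64 (presumably a
// toolkit-specific workaround); every other build keeps the 64x64 tiling.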
#if defined(CUDA_VERSION) && CUDA_VERSION == 12040 && !defined(USE_ROCM)
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 32, 32, 32>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 32, 32, 32>::kMinBlocksPerSm)
fmha_cutlassB_f32_aligned_32x32_k32_dropout_sm50(typename AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 32, 32, 32>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 32, 32, 64>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 32, 32, 64>::kMinBlocksPerSm)
fmha_cutlassB_f32_aligned_32x32_k64_dropout_sm50(typename AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 32, 32, 64>::Params p);
#else
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 64, 64, 32>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 64, 64, 32>::kMinBlocksPerSm)
fmha_cutlassB_f32_aligned_64x64_k32_dropout_sm50(typename AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 64, 64, 32>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 64, 64, 64>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 64, 64, 64>::kMinBlocksPerSm)
fmha_cutlassB_f32_aligned_64x64_k64_dropout_sm50(typename AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 64, 64, 64>::Params p);
#endif
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 64, 64, 128>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 64, 64, 128>::kMinBlocksPerSm)
fmha_cutlassB_f32_aligned_64x64_k128_dropout_sm50(typename AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 64, 64, 128>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 64, 64, 65536>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 64, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_f32_aligned_64x64_k65536_dropout_sm50(typename AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 64, 64, 65536>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm50, float, false, false, false, 64, 64, 32>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm50, float, false, false, false, 64, 64, 32>::kMinBlocksPerSm)
fmha_cutlassB_f32_notaligned_64x64_k32_sm50(typename AttentionBackwardKernel<cutlass::arch::Sm50, float, false, false, false, 64, 64, 32>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm50, float, false, false, false, 64, 64, 64>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm50, float, false, false, false, 64, 64, 64>::kMinBlocksPerSm)
fmha_cutlassB_f32_notaligned_64x64_k64_sm50(typename AttentionBackwardKernel<cutlass::arch::Sm50, float, false, false, false, 64, 64, 64>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm50, float, false, false, false, 64, 64, 128>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm50, float, false, false, false, 64, 64, 128>::kMinBlocksPerSm)
fmha_cutlassB_f32_notaligned_64x64_k128_sm50(typename AttentionBackwardKernel<cutlass::arch::Sm50, float, false, false, false, 64, 64, 128>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm50, float, false, false, false, 64, 64, 65536>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm50, float, false, false, false, 64, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_f32_notaligned_64x64_k65536_sm50(typename AttentionBackwardKernel<cutlass::arch::Sm50, float, false, false, false, 64, 64, 65536>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm50, float, false, true, false, 64, 64, 32>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm50, float, false, true, false, 64, 64, 32>::kMinBlocksPerSm)
fmha_cutlassB_f32_notaligned_64x64_k32_dropout_sm50(typename AttentionBackwardKernel<cutlass::arch::Sm50, float, false, true, false, 64, 64, 32>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm50, float, false, true, false, 64, 64, 64>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm50, float, false, true, false, 64, 64, 64>::kMinBlocksPerSm)
fmha_cutlassB_f32_notaligned_64x64_k64_dropout_sm50(typename AttentionBackwardKernel<cutlass::arch::Sm50, float, false, true, false, 64, 64, 64>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm50, float, false, true, false, 64, 64, 128>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm50, float, false, true, false, 64, 64, 128>::kMinBlocksPerSm)
fmha_cutlassB_f32_notaligned_64x64_k128_dropout_sm50(typename AttentionBackwardKernel<cutlass::arch::Sm50, float, false, true, false, 64, 64, 128>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm50, float, false, true, false, 64, 64, 65536>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm50, float, false, true, false, 64, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_f32_notaligned_64x64_k65536_dropout_sm50(typename AttentionBackwardKernel<cutlass::arch::Sm50, float, false, true, false, 64, 64, 65536>::Params p);

dispatch_cutlassB_f32_sm50(T cb,int cc)499 template <typename T> void dispatch_cutlassB_f32_sm50(T cb, int cc) {
500 cb(AttentionBackwardKernel<cutlass::arch::Sm50, float, true, false, false, 64, 64, 32>(), fmha_cutlassB_f32_aligned_64x64_k32_sm50);
501 cb(AttentionBackwardKernel<cutlass::arch::Sm50, float, true, false, false, 64, 64, 64>(), fmha_cutlassB_f32_aligned_64x64_k64_sm50);
502 cb(AttentionBackwardKernel<cutlass::arch::Sm50, float, true, false, false, 64, 64, 128>(), fmha_cutlassB_f32_aligned_64x64_k128_sm50);
503 cb(AttentionBackwardKernel<cutlass::arch::Sm50, float, true, false, false, 64, 64, 65536>(), fmha_cutlassB_f32_aligned_64x64_k65536_sm50);
#if defined(CUDA_VERSION) && CUDA_VERSION == 12040 && !defined(USE_ROCM)
    cb(AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 32, 32, 32>(), fmha_cutlassB_f32_aligned_32x32_k32_dropout_sm50);
    cb(AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 32, 32, 64>(), fmha_cutlassB_f32_aligned_32x32_k64_dropout_sm50);
#else
    cb(AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 64, 64, 32>(), fmha_cutlassB_f32_aligned_64x64_k32_dropout_sm50);
    cb(AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 64, 64, 64>(), fmha_cutlassB_f32_aligned_64x64_k64_dropout_sm50);
#endif
    cb(AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 64, 64, 128>(), fmha_cutlassB_f32_aligned_64x64_k128_dropout_sm50);
    cb(AttentionBackwardKernel<cutlass::arch::Sm50, float, true, true, false, 64, 64, 65536>(), fmha_cutlassB_f32_aligned_64x64_k65536_dropout_sm50);
    cb(AttentionBackwardKernel<cutlass::arch::Sm50, float, false, false, false, 64, 64, 32>(), fmha_cutlassB_f32_notaligned_64x64_k32_sm50);
    cb(AttentionBackwardKernel<cutlass::arch::Sm50, float, false, false, false, 64, 64, 64>(), fmha_cutlassB_f32_notaligned_64x64_k64_sm50);
    cb(AttentionBackwardKernel<cutlass::arch::Sm50, float, false, false, false, 64, 64, 128>(), fmha_cutlassB_f32_notaligned_64x64_k128_sm50);
    cb(AttentionBackwardKernel<cutlass::arch::Sm50, float, false, false, false, 64, 64, 65536>(), fmha_cutlassB_f32_notaligned_64x64_k65536_sm50);
    cb(AttentionBackwardKernel<cutlass::arch::Sm50, float, false, true, false, 64, 64, 32>(), fmha_cutlassB_f32_notaligned_64x64_k32_dropout_sm50);
    cb(AttentionBackwardKernel<cutlass::arch::Sm50, float, false, true, false, 64, 64, 64>(), fmha_cutlassB_f32_notaligned_64x64_k64_dropout_sm50);
    cb(AttentionBackwardKernel<cutlass::arch::Sm50, float, false, true, false, 64, 64, 128>(), fmha_cutlassB_f32_notaligned_64x64_k128_dropout_sm50);
    cb(AttentionBackwardKernel<cutlass::arch::Sm50, float, false, true, false, 64, 64, 65536>(), fmha_cutlassB_f32_notaligned_64x64_k65536_dropout_sm50);
}
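// Illustrative sketch (not part of the generated interface): each dispatch_cutlassB_*
// function calls `cb` once per kernel variant, passing a default-constructed kernel-traits
// object together with the matching __global__ entry point. A callback can therefore read
// static traits such as kNumThreads off `decltype(kernel)` and remember the entry point it
// wants to launch. `chosen` and the <= 512 filter are hypothetical, for this example only:
//
//   void* chosen = nullptr;  // type-erased entry point
//   dispatch_cutlassB_f32_sm50([&](auto kernel, auto* fn) {
//     using Kernel = decltype(kernel);
//     if (chosen == nullptr && Kernel::kNumThreads <= 512) {
//       chosen = reinterpret_cast<void*>(fn);
//     }
//   }, /*cc=*/50);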

// ======== f32 / sm70 ========
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm70, float, true, false, false, 64, 64, 32>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm70, float, true, false, false, 64, 64, 32>::kMinBlocksPerSm)
fmha_cutlassB_f32_aligned_64x64_k32_sm70(typename AttentionBackwardKernel<cutlass::arch::Sm70, float, true, false, false, 64, 64, 32>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm70, float, true, false, false, 64, 64, 64>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm70, float, true, false, false, 64, 64, 64>::kMinBlocksPerSm)
fmha_cutlassB_f32_aligned_64x64_k64_sm70(typename AttentionBackwardKernel<cutlass::arch::Sm70, float, true, false, false, 64, 64, 64>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm70, float, true, false, false, 64, 64, 128>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm70, float, true, false, false, 64, 64, 128>::kMinBlocksPerSm)
fmha_cutlassB_f32_aligned_64x64_k128_sm70(typename AttentionBackwardKernel<cutlass::arch::Sm70, float, true, false, false, 64, 64, 128>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm70, float, true, false, false, 64, 64, 65536>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm70, float, true, false, false, 64, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_f32_aligned_64x64_k65536_sm70(typename AttentionBackwardKernel<cutlass::arch::Sm70, float, true, false, false, 64, 64, 65536>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm70, float, true, true, false, 64, 64, 32>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm70, float, true, true, false, 64, 64, 32>::kMinBlocksPerSm)
fmha_cutlassB_f32_aligned_64x64_k32_dropout_sm70(typename AttentionBackwardKernel<cutlass::arch::Sm70, float, true, true, false, 64, 64, 32>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm70, float, true, true, false, 64, 64, 64>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm70, float, true, true, false, 64, 64, 64>::kMinBlocksPerSm)
fmha_cutlassB_f32_aligned_64x64_k64_dropout_sm70(typename AttentionBackwardKernel<cutlass::arch::Sm70, float, true, true, false, 64, 64, 64>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm70, float, true, true, false, 64, 64, 128>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm70, float, true, true, false, 64, 64, 128>::kMinBlocksPerSm)
fmha_cutlassB_f32_aligned_64x64_k128_dropout_sm70(typename AttentionBackwardKernel<cutlass::arch::Sm70, float, true, true, false, 64, 64, 128>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm70, float, true, true, false, 64, 64, 65536>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm70, float, true, true, false, 64, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_f32_aligned_64x64_k65536_dropout_sm70(typename AttentionBackwardKernel<cutlass::arch::Sm70, float, true, true, false, 64, 64, 65536>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm70, float, false, false, false, 64, 64, 32>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm70, float, false, false, false, 64, 64, 32>::kMinBlocksPerSm)
fmha_cutlassB_f32_notaligned_64x64_k32_sm70(typename AttentionBackwardKernel<cutlass::arch::Sm70, float, false, false, false, 64, 64, 32>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm70, float, false, false, false, 64, 64, 64>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm70, float, false, false, false, 64, 64, 64>::kMinBlocksPerSm)
fmha_cutlassB_f32_notaligned_64x64_k64_sm70(typename AttentionBackwardKernel<cutlass::arch::Sm70, float, false, false, false, 64, 64, 64>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm70, float, false, false, false, 64, 64, 128>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm70, float, false, false, false, 64, 64, 128>::kMinBlocksPerSm)
fmha_cutlassB_f32_notaligned_64x64_k128_sm70(typename AttentionBackwardKernel<cutlass::arch::Sm70, float, false, false, false, 64, 64, 128>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm70, float, false, false, false, 64, 64, 65536>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm70, float, false, false, false, 64, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_f32_notaligned_64x64_k65536_sm70(typename AttentionBackwardKernel<cutlass::arch::Sm70, float, false, false, false, 64, 64, 65536>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm70, float, false, true, false, 64, 64, 32>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm70, float, false, true, false, 64, 64, 32>::kMinBlocksPerSm)
fmha_cutlassB_f32_notaligned_64x64_k32_dropout_sm70(typename AttentionBackwardKernel<cutlass::arch::Sm70, float, false, true, false, 64, 64, 32>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm70, float, false, true, false, 64, 64, 64>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm70, float, false, true, false, 64, 64, 64>::kMinBlocksPerSm)
fmha_cutlassB_f32_notaligned_64x64_k64_dropout_sm70(typename AttentionBackwardKernel<cutlass::arch::Sm70, float, false, true, false, 64, 64, 64>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm70, float, false, true, false, 64, 64, 128>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm70, float, false, true, false, 64, 64, 128>::kMinBlocksPerSm)
fmha_cutlassB_f32_notaligned_64x64_k128_dropout_sm70(typename AttentionBackwardKernel<cutlass::arch::Sm70, float, false, true, false, 64, 64, 128>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm70, float, false, true, false, 64, 64, 65536>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm70, float, false, true, false, 64, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_f32_notaligned_64x64_k65536_dropout_sm70(typename AttentionBackwardKernel<cutlass::arch::Sm70, float, false, true, false, 64, 64, 65536>::Params p);

template <typename T> void dispatch_cutlassB_f32_sm70(T cb, int cc) {
    cb(AttentionBackwardKernel<cutlass::arch::Sm70, float, true, false, false, 64, 64, 32>(), fmha_cutlassB_f32_aligned_64x64_k32_sm70);
    cb(AttentionBackwardKernel<cutlass::arch::Sm70, float, true, false, false, 64, 64, 64>(), fmha_cutlassB_f32_aligned_64x64_k64_sm70);
    cb(AttentionBackwardKernel<cutlass::arch::Sm70, float, true, false, false, 64, 64, 128>(), fmha_cutlassB_f32_aligned_64x64_k128_sm70);
    cb(AttentionBackwardKernel<cutlass::arch::Sm70, float, true, false, false, 64, 64, 65536>(), fmha_cutlassB_f32_aligned_64x64_k65536_sm70);
    cb(AttentionBackwardKernel<cutlass::arch::Sm70, float, true, true, false, 64, 64, 32>(), fmha_cutlassB_f32_aligned_64x64_k32_dropout_sm70);
    cb(AttentionBackwardKernel<cutlass::arch::Sm70, float, true, true, false, 64, 64, 64>(), fmha_cutlassB_f32_aligned_64x64_k64_dropout_sm70);
    cb(AttentionBackwardKernel<cutlass::arch::Sm70, float, true, true, false, 64, 64, 128>(), fmha_cutlassB_f32_aligned_64x64_k128_dropout_sm70);
    cb(AttentionBackwardKernel<cutlass::arch::Sm70, float, true, true, false, 64, 64, 65536>(), fmha_cutlassB_f32_aligned_64x64_k65536_dropout_sm70);
    cb(AttentionBackwardKernel<cutlass::arch::Sm70, float, false, false, false, 64, 64, 32>(), fmha_cutlassB_f32_notaligned_64x64_k32_sm70);
    cb(AttentionBackwardKernel<cutlass::arch::Sm70, float, false, false, false, 64, 64, 64>(), fmha_cutlassB_f32_notaligned_64x64_k64_sm70);
    cb(AttentionBackwardKernel<cutlass::arch::Sm70, float, false, false, false, 64, 64, 128>(), fmha_cutlassB_f32_notaligned_64x64_k128_sm70);
    cb(AttentionBackwardKernel<cutlass::arch::Sm70, float, false, false, false, 64, 64, 65536>(), fmha_cutlassB_f32_notaligned_64x64_k65536_sm70);
    cb(AttentionBackwardKernel<cutlass::arch::Sm70, float, false, true, false, 64, 64, 32>(), fmha_cutlassB_f32_notaligned_64x64_k32_dropout_sm70);
    cb(AttentionBackwardKernel<cutlass::arch::Sm70, float, false, true, false, 64, 64, 64>(), fmha_cutlassB_f32_notaligned_64x64_k64_dropout_sm70);
    cb(AttentionBackwardKernel<cutlass::arch::Sm70, float, false, true, false, 64, 64, 128>(), fmha_cutlassB_f32_notaligned_64x64_k128_dropout_sm70);
    cb(AttentionBackwardKernel<cutlass::arch::Sm70, float, false, true, false, 64, 64, 65536>(), fmha_cutlassB_f32_notaligned_64x64_k65536_dropout_sm70);
}

// ======== f16 / sm75 ========
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, false, false, 64, 64, 32>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, false, false, 64, 64, 32>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_64x64_k32_sm75(typename AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, false, false, 64, 64, 32>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, false, false, 64, 64, 64>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, false, false, 64, 64, 64>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_64x64_k64_sm75(typename AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, false, false, 64, 64, 64>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, false, false, 128, 64, 128>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, false, false, 128, 64, 128>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_128x64_k128_sm75(typename AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, false, false, 128, 64, 128>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, false, false, 64, 64, 128>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, false, false, 64, 64, 128>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_64x64_k128_sm75(typename AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, false, false, 64, 64, 128>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, false, false, 128, 64, 65536>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, false, false, 128, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_128x64_k65536_sm75(typename AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, false, false, 128, 64, 65536>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, false, false, 64, 64, 65536>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, false, false, 64, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_64x64_k65536_sm75(typename AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, false, false, 64, 64, 65536>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, true, false, 64, 64, 32>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, true, false, 64, 64, 32>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_64x64_k32_dropout_sm75(typename AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, true, false, 64, 64, 32>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, true, false, 64, 64, 64>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, true, false, 64, 64, 64>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_64x64_k64_dropout_sm75(typename AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, true, false, 64, 64, 64>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, true, false, 128, 64, 128>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, true, false, 128, 64, 128>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_128x64_k128_dropout_sm75(typename AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, true, false, 128, 64, 128>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, true, false, 64, 64, 128>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, true, false, 64, 64, 128>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_64x64_k128_dropout_sm75(typename AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, true, false, 64, 64, 128>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, true, false, 128, 64, 65536>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, true, false, 128, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_128x64_k65536_dropout_sm75(typename AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, true, false, 128, 64, 65536>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, true, false, 64, 64, 65536>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, true, false, 64, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_f16_aligned_64x64_k65536_dropout_sm75(typename AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, true, false, 64, 64, 65536>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, false, false, false, 64, 64, 32>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, false, false, false, 64, 64, 32>::kMinBlocksPerSm)
fmha_cutlassB_f16_notaligned_64x64_k32_sm75(typename AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, false, false, false, 64, 64, 32>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, false, false, false, 64, 64, 64>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, false, false, false, 64, 64, 64>::kMinBlocksPerSm)
fmha_cutlassB_f16_notaligned_64x64_k64_sm75(typename AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, false, false, false, 64, 64, 64>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, false, false, false, 128, 64, 128>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, false, false, false, 128, 64, 128>::kMinBlocksPerSm)
fmha_cutlassB_f16_notaligned_128x64_k128_sm75(typename AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, false, false, false, 128, 64, 128>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, false, false, false, 64, 64, 128>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, false, false, false, 64, 64, 128>::kMinBlocksPerSm)
fmha_cutlassB_f16_notaligned_64x64_k128_sm75(typename AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, false, false, false, 64, 64, 128>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, false, false, false, 128, 64, 65536>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, false, false, false, 128, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_f16_notaligned_128x64_k65536_sm75(typename AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, false, false, false, 128, 64, 65536>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, false, false, false, 64, 64, 65536>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, false, false, false, 64, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_f16_notaligned_64x64_k65536_sm75(typename AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, false, false, false, 64, 64, 65536>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, false, true, false, 64, 64, 32>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, false, true, false, 64, 64, 32>::kMinBlocksPerSm)
fmha_cutlassB_f16_notaligned_64x64_k32_dropout_sm75(typename AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, false, true, false, 64, 64, 32>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, false, true, false, 64, 64, 64>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, false, true, false, 64, 64, 64>::kMinBlocksPerSm)
fmha_cutlassB_f16_notaligned_64x64_k64_dropout_sm75(typename AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, false, true, false, 64, 64, 64>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, false, true, false, 128, 64, 128>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, false, true, false, 128, 64, 128>::kMinBlocksPerSm)
fmha_cutlassB_f16_notaligned_128x64_k128_dropout_sm75(typename AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, false, true, false, 128, 64, 128>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, false, true, false, 64, 64, 128>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, false, true, false, 64, 64, 128>::kMinBlocksPerSm)
fmha_cutlassB_f16_notaligned_64x64_k128_dropout_sm75(typename AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, false, true, false, 64, 64, 128>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, false, true, false, 128, 64, 65536>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, false, true, false, 128, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_f16_notaligned_128x64_k65536_dropout_sm75(typename AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, false, true, false, 128, 64, 65536>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, false, true, false, 64, 64, 65536>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, false, true, false, 64, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_f16_notaligned_64x64_k65536_dropout_sm75(typename AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, false, true, false, 64, 64, 65536>::Params p);

template <typename T> void dispatch_cutlassB_f16_sm75(T cb, int cc) {
    cb(AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, false, false, 64, 64, 32>(), fmha_cutlassB_f16_aligned_64x64_k32_sm75);
    cb(AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, false, false, 64, 64, 64>(), fmha_cutlassB_f16_aligned_64x64_k64_sm75);
    cb(AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, false, false, 128, 64, 128>(), fmha_cutlassB_f16_aligned_128x64_k128_sm75);
    cb(AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, false, false, 64, 64, 128>(), fmha_cutlassB_f16_aligned_64x64_k128_sm75);
    cb(AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, false, false, 128, 64, 65536>(), fmha_cutlassB_f16_aligned_128x64_k65536_sm75);
    cb(AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, false, false, 64, 64, 65536>(), fmha_cutlassB_f16_aligned_64x64_k65536_sm75);
    cb(AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, true, false, 64, 64, 32>(), fmha_cutlassB_f16_aligned_64x64_k32_dropout_sm75);
    cb(AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, true, false, 64, 64, 64>(), fmha_cutlassB_f16_aligned_64x64_k64_dropout_sm75);
    cb(AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, true, false, 128, 64, 128>(), fmha_cutlassB_f16_aligned_128x64_k128_dropout_sm75);
    cb(AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, true, false, 64, 64, 128>(), fmha_cutlassB_f16_aligned_64x64_k128_dropout_sm75);
    cb(AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, true, false, 128, 64, 65536>(), fmha_cutlassB_f16_aligned_128x64_k65536_dropout_sm75);
    cb(AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, true, true, false, 64, 64, 65536>(), fmha_cutlassB_f16_aligned_64x64_k65536_dropout_sm75);
    cb(AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, false, false, false, 64, 64, 32>(), fmha_cutlassB_f16_notaligned_64x64_k32_sm75);
    cb(AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, false, false, false, 64, 64, 64>(), fmha_cutlassB_f16_notaligned_64x64_k64_sm75);
    cb(AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, false, false, false, 128, 64, 128>(), fmha_cutlassB_f16_notaligned_128x64_k128_sm75);
    cb(AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, false, false, false, 64, 64, 128>(), fmha_cutlassB_f16_notaligned_64x64_k128_sm75);
    cb(AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, false, false, false, 128, 64, 65536>(), fmha_cutlassB_f16_notaligned_128x64_k65536_sm75);
    cb(AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, false, false, false, 64, 64, 65536>(), fmha_cutlassB_f16_notaligned_64x64_k65536_sm75);
    cb(AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, false, true, false, 64, 64, 32>(), fmha_cutlassB_f16_notaligned_64x64_k32_dropout_sm75);
    cb(AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, false, true, false, 64, 64, 64>(), fmha_cutlassB_f16_notaligned_64x64_k64_dropout_sm75);
    cb(AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, false, true, false, 128, 64, 128>(), fmha_cutlassB_f16_notaligned_128x64_k128_dropout_sm75);
    cb(AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, false, true, false, 64, 64, 128>(), fmha_cutlassB_f16_notaligned_64x64_k128_dropout_sm75);
    cb(AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, false, true, false, 128, 64, 65536>(), fmha_cutlassB_f16_notaligned_128x64_k65536_dropout_sm75);
    cb(AttentionBackwardKernel<cutlass::arch::Sm75, cutlass::half_t, false, true, false, 64, 64, 65536>(), fmha_cutlassB_f16_notaligned_64x64_k65536_dropout_sm75);
}

// ======== f32 / sm75 ========
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm75, float, true, false, false, 64, 64, 32>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm75, float, true, false, false, 64, 64, 32>::kMinBlocksPerSm)
fmha_cutlassB_f32_aligned_64x64_k32_sm75(typename AttentionBackwardKernel<cutlass::arch::Sm75, float, true, false, false, 64, 64, 32>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm75, float, true, false, false, 64, 64, 64>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm75, float, true, false, false, 64, 64, 64>::kMinBlocksPerSm)
fmha_cutlassB_f32_aligned_64x64_k64_sm75(typename AttentionBackwardKernel<cutlass::arch::Sm75, float, true, false, false, 64, 64, 64>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm75, float, true, false, false, 64, 64, 128>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm75, float, true, false, false, 64, 64, 128>::kMinBlocksPerSm)
fmha_cutlassB_f32_aligned_64x64_k128_sm75(typename AttentionBackwardKernel<cutlass::arch::Sm75, float, true, false, false, 64, 64, 128>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm75, float, true, false, false, 64, 64, 65536>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm75, float, true, false, false, 64, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_f32_aligned_64x64_k65536_sm75(typename AttentionBackwardKernel<cutlass::arch::Sm75, float, true, false, false, 64, 64, 65536>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm75, float, true, true, false, 64, 64, 32>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm75, float, true, true, false, 64, 64, 32>::kMinBlocksPerSm)
fmha_cutlassB_f32_aligned_64x64_k32_dropout_sm75(typename AttentionBackwardKernel<cutlass::arch::Sm75, float, true, true, false, 64, 64, 32>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm75, float, true, true, false, 64, 64, 64>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm75, float, true, true, false, 64, 64, 64>::kMinBlocksPerSm)
fmha_cutlassB_f32_aligned_64x64_k64_dropout_sm75(typename AttentionBackwardKernel<cutlass::arch::Sm75, float, true, true, false, 64, 64, 64>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm75, float, true, true, false, 64, 64, 128>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm75, float, true, true, false, 64, 64, 128>::kMinBlocksPerSm)
fmha_cutlassB_f32_aligned_64x64_k128_dropout_sm75(typename AttentionBackwardKernel<cutlass::arch::Sm75, float, true, true, false, 64, 64, 128>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm75, float, true, true, false, 64, 64, 65536>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm75, float, true, true, false, 64, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_f32_aligned_64x64_k65536_dropout_sm75(typename AttentionBackwardKernel<cutlass::arch::Sm75, float, true, true, false, 64, 64, 65536>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm75, float, false, false, false, 64, 64, 32>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm75, float, false, false, false, 64, 64, 32>::kMinBlocksPerSm)
fmha_cutlassB_f32_notaligned_64x64_k32_sm75(typename AttentionBackwardKernel<cutlass::arch::Sm75, float, false, false, false, 64, 64, 32>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm75, float, false, false, false, 64, 64, 64>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm75, float, false, false, false, 64, 64, 64>::kMinBlocksPerSm)
fmha_cutlassB_f32_notaligned_64x64_k64_sm75(typename AttentionBackwardKernel<cutlass::arch::Sm75, float, false, false, false, 64, 64, 64>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm75, float, false, false, false, 64, 64, 128>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm75, float, false, false, false, 64, 64, 128>::kMinBlocksPerSm)
fmha_cutlassB_f32_notaligned_64x64_k128_sm75(typename AttentionBackwardKernel<cutlass::arch::Sm75, float, false, false, false, 64, 64, 128>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm75, float, false, false, false, 64, 64, 65536>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm75, float, false, false, false, 64, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_f32_notaligned_64x64_k65536_sm75(typename AttentionBackwardKernel<cutlass::arch::Sm75, float, false, false, false, 64, 64, 65536>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm75, float, false, true, false, 64, 64, 32>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm75, float, false, true, false, 64, 64, 32>::kMinBlocksPerSm)
fmha_cutlassB_f32_notaligned_64x64_k32_dropout_sm75(typename AttentionBackwardKernel<cutlass::arch::Sm75, float, false, true, false, 64, 64, 32>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm75, float, false, true, false, 64, 64, 64>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm75, float, false, true, false, 64, 64, 64>::kMinBlocksPerSm)
fmha_cutlassB_f32_notaligned_64x64_k64_dropout_sm75(typename AttentionBackwardKernel<cutlass::arch::Sm75, float, false, true, false, 64, 64, 64>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm75, float, false, true, false, 64, 64, 128>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm75, float, false, true, false, 64, 64, 128>::kMinBlocksPerSm)
fmha_cutlassB_f32_notaligned_64x64_k128_dropout_sm75(typename AttentionBackwardKernel<cutlass::arch::Sm75, float, false, true, false, 64, 64, 128>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm75, float, false, true, false, 64, 64, 65536>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm75, float, false, true, false, 64, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_f32_notaligned_64x64_k65536_dropout_sm75(typename AttentionBackwardKernel<cutlass::arch::Sm75, float, false, true, false, 64, 64, 65536>::Params p);

template <typename T> void dispatch_cutlassB_f32_sm75(T cb, int cc) {
    cb(AttentionBackwardKernel<cutlass::arch::Sm75, float, true, false, false, 64, 64, 32>(), fmha_cutlassB_f32_aligned_64x64_k32_sm75);
    cb(AttentionBackwardKernel<cutlass::arch::Sm75, float, true, false, false, 64, 64, 64>(), fmha_cutlassB_f32_aligned_64x64_k64_sm75);
    cb(AttentionBackwardKernel<cutlass::arch::Sm75, float, true, false, false, 64, 64, 128>(), fmha_cutlassB_f32_aligned_64x64_k128_sm75);
    cb(AttentionBackwardKernel<cutlass::arch::Sm75, float, true, false, false, 64, 64, 65536>(), fmha_cutlassB_f32_aligned_64x64_k65536_sm75);
    cb(AttentionBackwardKernel<cutlass::arch::Sm75, float, true, true, false, 64, 64, 32>(), fmha_cutlassB_f32_aligned_64x64_k32_dropout_sm75);
    cb(AttentionBackwardKernel<cutlass::arch::Sm75, float, true, true, false, 64, 64, 64>(), fmha_cutlassB_f32_aligned_64x64_k64_dropout_sm75);
    cb(AttentionBackwardKernel<cutlass::arch::Sm75, float, true, true, false, 64, 64, 128>(), fmha_cutlassB_f32_aligned_64x64_k128_dropout_sm75);
    cb(AttentionBackwardKernel<cutlass::arch::Sm75, float, true, true, false, 64, 64, 65536>(), fmha_cutlassB_f32_aligned_64x64_k65536_dropout_sm75);
    cb(AttentionBackwardKernel<cutlass::arch::Sm75, float, false, false, false, 64, 64, 32>(), fmha_cutlassB_f32_notaligned_64x64_k32_sm75);
    cb(AttentionBackwardKernel<cutlass::arch::Sm75, float, false, false, false, 64, 64, 64>(), fmha_cutlassB_f32_notaligned_64x64_k64_sm75);
    cb(AttentionBackwardKernel<cutlass::arch::Sm75, float, false, false, false, 64, 64, 128>(), fmha_cutlassB_f32_notaligned_64x64_k128_sm75);
    cb(AttentionBackwardKernel<cutlass::arch::Sm75, float, false, false, false, 64, 64, 65536>(), fmha_cutlassB_f32_notaligned_64x64_k65536_sm75);
    cb(AttentionBackwardKernel<cutlass::arch::Sm75, float, false, true, false, 64, 64, 32>(), fmha_cutlassB_f32_notaligned_64x64_k32_dropout_sm75);
    cb(AttentionBackwardKernel<cutlass::arch::Sm75, float, false, true, false, 64, 64, 64>(), fmha_cutlassB_f32_notaligned_64x64_k64_dropout_sm75);
    cb(AttentionBackwardKernel<cutlass::arch::Sm75, float, false, true, false, 64, 64, 128>(), fmha_cutlassB_f32_notaligned_64x64_k128_dropout_sm75);
    cb(AttentionBackwardKernel<cutlass::arch::Sm75, float, false, true, false, 64, 64, 65536>(), fmha_cutlassB_f32_notaligned_64x64_k65536_dropout_sm75);
}

// ======== f32 / sm80 ========
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm80, float, true, false, false, 64, 64, 32>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm80, float, true, false, false, 64, 64, 32>::kMinBlocksPerSm)
fmha_cutlassB_f32_aligned_64x64_k32_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, float, true, false, false, 64, 64, 32>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm80, float, true, false, false, 64, 64, 64>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm80, float, true, false, false, 64, 64, 64>::kMinBlocksPerSm)
fmha_cutlassB_f32_aligned_64x64_k64_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, float, true, false, false, 64, 64, 64>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm80, float, true, false, false, 128, 64, 128>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm80, float, true, false, false, 128, 64, 128>::kMinBlocksPerSm)
fmha_cutlassB_f32_aligned_128x64_k128_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, float, true, false, false, 128, 64, 128>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm80, float, true, false, false, 64, 64, 128>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm80, float, true, false, false, 64, 64, 128>::kMinBlocksPerSm)
fmha_cutlassB_f32_aligned_64x64_k128_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, float, true, false, false, 64, 64, 128>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm80, float, true, false, false, 128, 64, 65536>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm80, float, true, false, false, 128, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_f32_aligned_128x64_k65536_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, float, true, false, false, 128, 64, 65536>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm80, float, true, false, false, 64, 64, 65536>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm80, float, true, false, false, 64, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_f32_aligned_64x64_k65536_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, float, true, false, false, 64, 64, 65536>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm80, float, true, true, false, 64, 64, 32>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm80, float, true, true, false, 64, 64, 32>::kMinBlocksPerSm)
fmha_cutlassB_f32_aligned_64x64_k32_dropout_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, float, true, true, false, 64, 64, 32>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm80, float, true, true, false, 64, 64, 64>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm80, float, true, true, false, 64, 64, 64>::kMinBlocksPerSm)
fmha_cutlassB_f32_aligned_64x64_k64_dropout_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, float, true, true, false, 64, 64, 64>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm80, float, true, true, false, 128, 64, 128>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm80, float, true, true, false, 128, 64, 128>::kMinBlocksPerSm)
fmha_cutlassB_f32_aligned_128x64_k128_dropout_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, float, true, true, false, 128, 64, 128>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm80, float, true, true, false, 64, 64, 128>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm80, float, true, true, false, 64, 64, 128>::kMinBlocksPerSm)
fmha_cutlassB_f32_aligned_64x64_k128_dropout_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, float, true, true, false, 64, 64, 128>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm80, float, true, true, false, 128, 64, 65536>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm80, float, true, true, false, 128, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_f32_aligned_128x64_k65536_dropout_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, float, true, true, false, 128, 64, 65536>::Params p);
__global__ void __launch_bounds__(
    AttentionBackwardKernel<cutlass::arch::Sm80, float, true, true, false, 64, 64, 65536>::kNumThreads,
    AttentionBackwardKernel<cutlass::arch::Sm80, float, true, true, false, 64, 64, 65536>::kMinBlocksPerSm)
fmha_cutlassB_f32_aligned_64x64_k65536_dropout_sm80(typename AttentionBackwardKernel<cutlass::arch::Sm80, float, true, true, false, 64, 64, 65536>::Params p);

template <typename T> void dispatch_cutlassB_f32_sm80(T cb, int cc) {
    cb(AttentionBackwardKernel<cutlass::arch::Sm80, float, true, false, false, 64, 64, 32>(), fmha_cutlassB_f32_aligned_64x64_k32_sm80);
    cb(AttentionBackwardKernel<cutlass::arch::Sm80, float, true, false, false, 64, 64, 64>(), fmha_cutlassB_f32_aligned_64x64_k64_sm80);
    cb(AttentionBackwardKernel<cutlass::arch::Sm80, float, true, false, false, 128, 64, 128>(), fmha_cutlassB_f32_aligned_128x64_k128_sm80);
    cb(AttentionBackwardKernel<cutlass::arch::Sm80, float, true, false, false, 64, 64, 128>(), fmha_cutlassB_f32_aligned_64x64_k128_sm80);
    cb(AttentionBackwardKernel<cutlass::arch::Sm80, float, true, false, false, 128, 64, 65536>(), fmha_cutlassB_f32_aligned_128x64_k65536_sm80);
    cb(AttentionBackwardKernel<cutlass::arch::Sm80, float, true, false, false, 64, 64, 65536>(), fmha_cutlassB_f32_aligned_64x64_k65536_sm80);
    cb(AttentionBackwardKernel<cutlass::arch::Sm80, float, true, true, false, 64, 64, 32>(), fmha_cutlassB_f32_aligned_64x64_k32_dropout_sm80);
    cb(AttentionBackwardKernel<cutlass::arch::Sm80, float, true, true, false, 64, 64, 64>(), fmha_cutlassB_f32_aligned_64x64_k64_dropout_sm80);
    cb(AttentionBackwardKernel<cutlass::arch::Sm80, float, true, true, false, 128, 64, 128>(), fmha_cutlassB_f32_aligned_128x64_k128_dropout_sm80);
    cb(AttentionBackwardKernel<cutlass::arch::Sm80, float, true, true, false, 64, 64, 128>(), fmha_cutlassB_f32_aligned_64x64_k128_dropout_sm80);
    cb(AttentionBackwardKernel<cutlass::arch::Sm80, float, true, true, false, 128, 64, 65536>(), fmha_cutlassB_f32_aligned_128x64_k65536_dropout_sm80);
    cb(AttentionBackwardKernel<cutlass::arch::Sm80, float, true, true, false, 64, 64, 65536>(), fmha_cutlassB_f32_aligned_64x64_k65536_dropout_sm80);
}
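// Note: unlike the Sm50/Sm70/Sm75 sections above, the f32 / sm80 section provides only
// aligned kernel variants; there are no fmha_cutlassB_f32_notaligned_*_sm80 entries in
// this file.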


template <typename DT, typename T>
void dispatch_cutlassB(T cb, int cc = 0) {

    if (std::is_same<DT, cutlass::half_t>::value && 70 <= cc && cc < 75) {
        dispatch_cutlassB_f16_sm70(cb, cc);
    }
    if (std::is_same<DT, cutlass::bfloat16_t>::value && 80 <= cc && cc < 100) {
        dispatch_cutlassB_bf16_sm80(cb, cc);
    }
    if (std::is_same<DT, cutlass::half_t>::value && 80 <= cc && cc < 100) {
        dispatch_cutlassB_f16_sm80(cb, cc);
    }
    if (std::is_same<DT, cutlass::half_t>::value && 50 <= cc && cc < 70) {
        dispatch_cutlassB_f16_sm50(cb, cc);
    }
    if (std::is_same<DT, float>::value && 50 <= cc && cc < 70) {
        dispatch_cutlassB_f32_sm50(cb, cc);
    }
    if (std::is_same<DT, float>::value && 70 <= cc && cc < 75) {
        dispatch_cutlassB_f32_sm70(cb, cc);
    }
    if (std::is_same<DT, cutlass::half_t>::value && 75 <= cc && cc < 80) {
        dispatch_cutlassB_f16_sm75(cb, cc);
    }
    if (std::is_same<DT, float>::value && 75 <= cc && cc < 80) {
        dispatch_cutlassB_f32_sm75(cb, cc);
    }
    if (std::is_same<DT, float>::value && 80 <= cc && cc < 100) {
        dispatch_cutlassB_f32_sm80(cb, cc);
    }
}
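// Usage sketch (illustrative only; the lambda body is left to the caller): the dtype is
// selected at compile time and `cc` is the runtime compute capability encoded as
// 10 * major + minor, matching the ranges tested above.
//
//   cudaDeviceProp props;
//   cudaGetDeviceProperties(&props, /*device=*/0);
//   int cc = props.major * 10 + props.minor;
//   dispatch_cutlassB<cutlass::half_t>([&](auto kernel, auto* fn) {
//     // decltype(kernel) exposes kNumThreads / kMinBlocksPerSm / Params for launching fn
//   }, cc);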