xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 #pragma once
9 
10 #include <cmath>
11 #include <type_traits>
12 #include <vector>
13 
14 #include <cuda_fp16.h>
15 #include <curand_kernel.h>
16 
17 #include <ATen/cuda/PhiloxUtils.cuh>
18 #include <cutlass/cutlass.h>
19 #include <cutlass/epilogue/thread/linear_combination.h>
20 #include <cutlass/epilogue/thread/scale_type.h>
21 #include <cutlass/fast_math.h>
22 #include <cutlass/functional.h>
23 #include <cutlass/gemm/gemm.h>
24 #include <cutlass/layout/matrix.h>
25 #include <cutlass/layout/vector.h>
26 #include <cutlass/numeric_conversion.h>
27 #include <cutlass/numeric_types.h>
28 #include <cutlass/tensor_ref.h>
29 
30 #include <cutlass/epilogue/thread/linear_combination_relu.h>
31 #include <cutlass/epilogue/threadblock/epilogue_smem_accumulator.h>
32 #include <cutlass/epilogue/warp/fragment_iterator_tensor_op.h>
33 #include <cutlass/epilogue/warp/tile_iterator_tensor_op.h>
34 #include <cutlass/gemm/device/default_gemm_configuration.h>
35 #include <cutlass/gemm/kernel/default_gemm.h>
36 #include <cutlass/gemm/threadblock/default_mma.h>
37 #include <cutlass/gemm/threadblock/default_mma_core_simt.h>
38 #include <cutlass/gemm/threadblock/default_mma_core_sm70.h>
39 #include <cutlass/gemm/threadblock/default_mma_core_sm75.h>
40 #include <cutlass/gemm/threadblock/default_mma_core_sm80.h>
41 #include <cutlass/integer_subbyte.h>
42 #include <cutlass/matrix_shape.h>
43 #include <cutlass/platform/platform.h>
44 #include <cutlass/transform/threadblock/predicated_tile_iterator.h>
45 #include <cutlass/transform/threadblock/vector_iterator.h>
46 
47 #include <ATen/native/transformers/cuda/mem_eff_attention/debug_utils.h>
48 #include <ATen/native/transformers/cuda/mem_eff_attention/gemm_kernel_utils.h>
49 
50 #include <ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma.h>
51 #include <ATen/native/transformers/cuda/mem_eff_attention/gemm/find_default_mma.h>
52 #include <ATen/native/transformers/cuda/mem_eff_attention/gemm/mma_accum_lambda_iterator.h>
53 #include <ATen/native/transformers/cuda/mem_eff_attention/gemm/mma_from_smem.h>
54 
55 #include <ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_pipelined.h>
56 #include <ATen/native/transformers/cuda/mem_eff_attention/iterators/epilogue_predicated_tile_iterator.h>
57 #include <ATen/native/transformers/cuda/mem_eff_attention/transform/tile_smem_loader.h>
58 
59 #include <cinttypes>
60 #include <c10/util/Exception.h>
61 
62 using namespace gemm_kernel_utils;
63 
64 namespace PyTorchMemEffAttention {
65 namespace {
66 
67 template <typename FragmentType, int32_t kNumThreads>
68 struct GmemTile {
69   /*
70     Helper functions to efficiently store/load RF to/from gmem
71 
72     GEMM accumulators have a particular format on A100, and
73     it takes some compute/shared-memory to rearrange them to
74     a RowMajor or ColumnMajor format in global memory through
75     an Epilogue. The same complexity goes for loading into RF.
76 
77     This class loads/stores RF as they are, and can be used for
78     efficient accumulation across gemms, for instance:
79 
80     ```
81     GmemTile tile;
82     for (int i = 0; i < N; ++i) {
83       // ...
84 
85       Fragment accum;
86       if (i == 0) {
87         accum.clear();
88       } else {
89         tile.load(accum);
90       }
91       mma(accum, ...);
92       if (i < N-1) {
93         // Store for next GEMM
94         tile.store(accum);
95       } else {
96         // Store in tensor (eg RowMajor)
97         epilogue(accum);
98       }
99 
100       // ...
101     }
102     ```
103   */
104 
105   // 128bits per thread
106   using AccessType = cutlass::Array<float, 4>;
107   static constexpr int32_t kBytes = sizeof(AccessType);
108   static constexpr int32_t kStride = kNumThreads * AccessType::kElements;
109   static constexpr int32_t kNumIters =
110       FragmentType::kElements / AccessType::kElements;
111   static constexpr int32_t kElementsStored =
112       kNumThreads * FragmentType::kElements;
113   static_assert(
114       FragmentType::kElements % AccessType::kElements == 0,
115       "fragment not aligned on 128 bits");
116 
117   float* ptr;
118 
119   CUTLASS_DEVICE void load(FragmentType& fragment, int thread_id) {
120     CUTLASS_PRAGMA_UNROLL
121     for (int i = 0; i < kNumIters; ++i) {
122       AccessType* __restrict__ gmem_ptr = reinterpret_cast<AccessType*>(
123           ptr + thread_id * AccessType::kElements + i * kStride);
124       AccessType sub_fragment;
125       cutlass::arch::global_load<AccessType, kBytes>(
126           sub_fragment, gmem_ptr, true);
127       CUTLASS_PRAGMA_UNROLL
128       for (int j = 0; j < AccessType::kElements; ++j) {
129         fragment[i * AccessType::kElements + j] = sub_fragment[j];
130       }
131     }
132   }
133 
134   CUTLASS_DEVICE void store(FragmentType const& fragment, int thread_id) {
135     CUTLASS_PRAGMA_UNROLL
136     for (int i = 0; i < kNumIters; ++i) {
137       AccessType* __restrict__ gmem_ptr = reinterpret_cast<AccessType*>(
138           ptr + thread_id * AccessType::kElements + i * kStride);
139       AccessType sub_fragment;
140       CUTLASS_PRAGMA_UNROLL
141       for (int j = 0; j < AccessType::kElements; ++j) {
142         sub_fragment[j] = fragment[i * AccessType::kElements + j];
143       }
144       cutlass::arch::global_store<AccessType, kBytes>(
145           sub_fragment, gmem_ptr, true);
146     }
147   }
148 
149   CUTLASS_DEVICE void storeAtomicAdd(
150       FragmentType const& fragment,
151       int thread_id) {
152     CUTLASS_PRAGMA_UNROLL
153     for (int i = 0; i < kNumIters; ++i) {
154       float* gmem_ptr = ptr + thread_id * AccessType::kElements + i * kStride;
155       CUTLASS_PRAGMA_UNROLL
156       for (int j = 0; j < AccessType::kElements; ++j) {
157         float val = fragment[i * AccessType::kElements + j];
158         float* ptr = gmem_ptr + j;
159         atomicAdd(ptr, val);
160       }
161     }
162   }
163 };
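// Illustrative usage sketch (not part of the kernel; `Fragment`, `workspace`
// and `frag` are hypothetical names): when several threadblocks accumulate
// partial results into the same gmem region, `storeAtomicAdd` replaces the
// load / add / store round-trip:
//
//   GmemTile<Fragment, kNumThreads> tile;
//   tile.ptr = workspace;                  // float scratch buffer in gmem
//   tile.storeAtomicAdd(frag, thread_id);  // += this block's partial result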
164 
165 struct AtomicLock {
166   CUTLASS_DEVICE static void acquire(
167       int32_t* lock,
168       int set_val,
169       int thread_id) {
170     if (thread_id == 0) {
171       while (atomicCAS(lock, 0 /*cmp*/, set_val /*setval*/) != set_val) {
172 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
173         __nanosleep(40);
174 #endif
175       }
176     }
177     __syncthreads();
178   }
179   CUTLASS_DEVICE static void release(int32_t* lock, int thread_id) {
180     if (thread_id == 0) {
181       int status = 0;
182 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
183       asm volatile("st.global.release.gpu.b32 [%0], %1;\n"
184                    :
185                    : "l"(lock), "r"(status));
186 #else
187       asm volatile("st.global.cg.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
188 #endif
189     }
190   }
191 };
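// Usage sketch (illustrative only): thread 0 of a block spins until it owns
// the lock, `acquire` ends with __syncthreads() so the whole block enters the
// critical section together, and after another __syncthreads() thread 0
// releases with a release-ordered store so other blocks observe the protected
// gmem writes. Note that `acquire` returns once the lock equals `set_val`, so
// concurrent blocks must pass distinct non-zero `set_val`s.
//
//   AtomicLock::acquire(lock_ptr, blockIdx.x + 1, thread_id);
//   // ... read-modify-write data shared across blocks ...
//   __syncthreads();
//   AtomicLock::release(lock_ptr, thread_id);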
192 
193 template <typename scalar_t, typename Arch>
194 constexpr int getWarpsPerSmBw() {
195   bool is_half = !cutlass::platform::is_same<scalar_t, float>::value;
196   if (Arch::kMinComputeCapability >= 80) {
197     return is_half ? 12 : 8;
198   }
199   return 8;
200 }
201 } // namespace
202 
203 template <
204     // which arch we target (eg `cutlass::arch::Sm80`)
205     typename ArchTag_,
206     // input/output type
207     typename scalar_t_,
208     // run optimized kernel because memory accesses will be aligned
209     bool kIsAligned_,
210     // use dropout if enabled
211     bool kApplyDropout_,
212     // when doing a GEMM, preload the next one (uses more shmem)
213     bool kPreload_,
214     // block dimensions
215     int kBlockSizeI_,
216     int kBlockSizeJ_,
217     // upper bound on `max(value.shape[-1], query.shape[-1])`
218     int kMaxK_ = (int)cutlass::platform::numeric_limits<uint32_t>::max(),
219     // assumes that `cu_seqlen` is None, and
220     // (1) `num_queries % kBlockSizeI == 0`
221     // (2) `num_keys % kBlockSizeJ == 0`
222     bool kKeysQueriesAlignedToBlockSize_ = false>
223 struct AttentionBackwardKernel {
224   enum CustomMaskType {
225     NoCustomMask = 0,
226     CausalFromTopLeft = 1,
227     CausalFromBottomRight = 2,
228     NumCustomMaskTypes,
229   };
230   using scalar_t = scalar_t_;
231   using output_t = scalar_t;
232   using output_accum_t = float;
233   using lse_scalar_t = float;
234   using accum_t = float;
235   using ArchTag = ArchTag_;
236   static constexpr bool kIsAligned = kIsAligned_;
237   static constexpr bool kApplyDropout = kApplyDropout_;
238   static constexpr bool kPreload = kPreload_;
239   static constexpr int kBlockSizeI = kBlockSizeI_;
240   static constexpr int kBlockSizeJ = kBlockSizeJ_;
241   static constexpr int kMaxK = kMaxK_;
242   static constexpr bool kKeysQueriesAlignedToBlockSize =
243       kKeysQueriesAlignedToBlockSize_;
244 
245   static constexpr int64_t kWarpSize = 32;
246 
247   // If this is true, we store and accumulate dK/dV in RF
248   // rather than going back to gmem every time
249   static constexpr bool kIsHalf = cutlass::sizeof_bits<scalar_t>::value <= 16;
250   static constexpr bool kOutputInRF = kIsHalf && kMaxK <= kBlockSizeI;
251   static_assert(
252       !kPreload ||
253           (kIsHalf && ArchTag::kMinComputeCapability >= 80 && kOutputInRF),
254       "preload MMA not supported");
255   static constexpr bool kPrologueQK = kPreload;
256   static constexpr bool kPrologueGV = kPreload;
257   static constexpr bool kPrologueDOV = kPreload;
258   static constexpr bool kPrologueGQ = kPreload;
259   static constexpr bool kPrologueGK = kPreload;
260 
261   static constexpr int64_t kNumWarpsPerBlock =
262       (kBlockSizeI * kBlockSizeJ) / (32 * 32);
263 
264   // Compute delta for the f16 kernels
265   // TODO: Figure out why it's slower on the f32 kernels
266   // (something due to RF pressure?)
267   // TODO: Remove condition on `kOutputInRF` - this is needed to work
268   // around a compiler bug on V100, not exactly sure why but I spent
269   // too much time on this already. Reproducible with
270   // (B, Mq, Mkv, K) = (1, 1, 1, 136) for instance
271   static constexpr bool kKernelComputesDelta =
272       kIsHalf && (kOutputInRF || ArchTag::kMinComputeCapability != 70);
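  // `delta` is Di = (dO_i * O_i).sum(-1) (see `computeDelta` and `p.delta_ptr`);
  // it is the correction term in dSij = Pij * (dPij - Di). When the kernel does
  // not compute it, it is presumably filled in ahead of time by the caller.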
273 
274   // Launch bounds
275   static constexpr int64_t kNumThreads = kWarpSize * kNumWarpsPerBlock;
276   static constexpr int64_t kMinBlocksPerSm =
277       getWarpsPerSmBw<scalar_t, ArchTag>() / kNumWarpsPerBlock;
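  // Worked example (illustrative): with kBlockSizeI = kBlockSizeJ = 64 a block
  // has (64 * 64) / (32 * 32) = 4 warps = 128 threads; on Sm80 with a 16-bit
  // scalar_t, getWarpsPerSmBw() returns 12, so kMinBlocksPerSm = 12 / 4 = 3
  // resident blocks per SM for the launch bounds.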
278 
279   using GemmType = DefaultGemmType<ArchTag, scalar_t>;
280   using DefaultConfig =
281       typename cutlass::gemm::device::DefaultGemmConfiguration<
282           typename GemmType::OpClass,
283           ArchTag,
284           scalar_t,
285           scalar_t,
286           scalar_t, // ElementC
287           accum_t // ElementAccumulator
288           >;
289   static constexpr auto kOptimalAlignement = cutlass::platform::max(
290       DefaultConfig::kAlignmentA,
291       DefaultConfig::kAlignmentB);
292   static constexpr auto kMinimumAlignment = GemmType::kMinimumAlignment;
293 
294   struct MatmulQK {
295     /*
296     attn_T = k_j @ q_i.transpose(-2, -1) # matmul
297     attn_T = (attn_T - logsumexp[i_start:i_end].unsqueeze(1).transpose(-2,
298     -1)).exp() # epilogue
299 
300     with attn_T.shape = (kBlockSizeJ, kBlockSizeI)
301     */
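    // Note: `logsumexp` is the row-wise log-sum-exp saved by the forward pass,
    // so exp(k_j @ q_i.T - lse_i) recomputes the softmax probabilities Pij.T
    // here instead of having stored the attention matrix in the forward pass.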
302     using ThreadblockShape =
303         cutlass::gemm::GemmShape<kBlockSizeJ, kBlockSizeI, GemmType::ThreadK>;
304     using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
305     using DefaultMma = typename cutlass::gemm::threadblock::DefaultMma<
306         scalar_t, // ElementA
307         cutlass::layout::RowMajor, // LayoutA
308         kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment,
309         scalar_t, // ElementB
310         cutlass::layout::ColumnMajor, // LayoutB
311         kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment,
312         accum_t, // ElementC
313         cutlass::layout::RowMajor, // LayoutC
314         typename GemmType::OpClass,
315         ArchTag,
316         ThreadblockShape,
317         WarpShape,
318         typename GemmType::InstructionShape,
319         DefaultConfig::kStages,
320         typename GemmType::Operator,
321         false, // AccumulatorsInRowMajor = false,
322         cutlass::gemm::SharedMemoryClearOption::kNone>;
323     using MmaCore = typename DefaultMma::MmaCore;
324     using Mma =
325         typename MakeCustomMma<typename DefaultMma::ThreadblockMma, kMaxK>::Mma;
326 
327     // used for efficient load of bias tile (Bij) from global memory to shared
328     // memory
329     using BiasLoader = TileSmemLoader<
330         scalar_t,
331         // Bij is applied to transposed attn matrix tile (Pij.T). Bij is loaded
332         // row-major but needs to have transposed shape so we get the same
333         // elements.
334         cutlass::MatrixShape<ThreadblockShape::kN, ThreadblockShape::kM>,
335         MmaCore::kThreads,
336         // input restriction: kv_len has to be a multiple of this value
337         128 / cutlass::sizeof_bits<scalar_t>::value>;
338 
339     // Epilogue to store to shared-memory in a format that we can use later for
340     // the second matmul
341     using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm<
342         typename Mma::Operator::IteratorC,
343         typename Mma::Operator,
344         scalar_t,
345         WarpShape,
346         ThreadblockShape>;
347     using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator<
348         typename Mma::Operator::IteratorC,
349         accum_t,
350         kWarpSize>::Iterator;
351     using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage;
352   };
353 
354   struct MatmulGradV {
355     /*
356     grad_v[j_start:j_end] += attn_T @ do_i # matmul
357 
358     Dimensions: (kBlockSizeJ * kNumWarpsPerBlock, kBlockSizeI, K)
359     (we might need to iterate multiple times on K)
360     */
361     using ThreadblockShape =
362         cutlass::gemm::GemmShape<kBlockSizeJ, kBlockSizeI, GemmType::ThreadK>;
363     using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
364     using InstructionShape = typename GemmType::InstructionShape;
365 
366     using DefaultGemm = cutlass::gemm::kernel::DefaultGemm<
367         scalar_t, // ElementA,
368         cutlass::layout::RowMajor, // LayoutA,
369         DefaultConfig::kAlignmentA,
370         scalar_t, // ElementB,
371         cutlass::layout::RowMajor, // LayoutB,
372         kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment,
373         output_t,
374         cutlass::layout::RowMajor, // LayoutC,
375         accum_t,
376         typename GemmType::OpClass,
377         ArchTag,
378         ThreadblockShape,
379         WarpShape,
380         typename GemmType::InstructionShape,
381         typename DefaultConfig::EpilogueOutputOp,
382         void, // ThreadblockSwizzle - not used
383         DefaultConfig::kStages,
384         false, // SplitKSerial
385         typename GemmType::Operator>;
386 
387     // if dropout:
388     //   for computing dVj += (Pij.T * Zij) @ dOi
389     //   Pij_dropped.T = Pij.T * Zij is computed on the fly as fragments of
390     //   Pij.T are loaded in. The reason we do it this way is because Pij.T and
391     //   Zij are reused in later steps, while Pij_dropped.T is only needed in
392     //   this step. computing Pij_dropped.T on the fly allows us to avoid
393     //   keeping all 3 of Pij_dropped.T, Pij.T, and Zij in shared memory at the
394     //   same time.
395     // if no dropout:
396     //   for computing dVj += Pij.T @ dOi
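    // In pseudo-code (illustrative only), per fragment of Pij.T streamed from
    // shared memory:
    //   frag = kApplyDropout ? frag * zij_frag : frag;  // apply dropout mask
    //   dVj_accum += frag @ dOi_frag;                   // accumulate into dVj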
397     using WarpIteratorA = typename cutlass::gemm::threadblock::
398         DefaultWarpIteratorAFromSharedMemory<
399             typename DefaultGemm::Mma::Operator::Shape, // WarpShape
400             typename DefaultGemm::Mma::Operator::
401                 InstructionShape, // InstructionShape
402             typename DefaultGemm::Mma::Operator::
403                 IteratorA, // RegularWarpIterator
404             typename DefaultGemm::Mma::Policy // Policy
405             >::WarpIterator;
406     using DefaultMmaFromSmem =
407         typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
408             typename DefaultGemm::Mma,
409             MatmulQK::AccumulatorSharedStorage::Shape::kN,
410             WarpIteratorA,
411             kApplyDropout>; // kScaleOperandA
412 
413     using Mma = typename DefaultMmaFromSmem::Mma;
414     using IteratorB = typename Mma::IteratorB;
415     using WarpCount = typename Mma::WarpCount;
416 
417     // Epilogue
418     using DefaultOutputOp = typename DefaultConfig::EpilogueOutputOp;
419     using DefaultEpilogue = typename DefaultGemm::Epilogue;
420     using OutputTileIterator =
421         typename cutlass::epilogue::threadblock::MakePrefetchableIterator<
422             typename DefaultEpilogue::OutputTileIterator>::Iterator;
423     using AccumTileGmem = GmemTile<typename Mma::FragmentC, (int)kNumThreads>;
424   };
425 
426   struct MatmulDOIVJ {
427     /*
428     doi_t_vj = do_i @ v_j.transpose(-2, -1) # matmul
429     tmp = (doi_t_vj - Di.unsqueeze(1)) * attn # inplace / epilogue?
430     */
431     using ThreadblockShape =
432         cutlass::gemm::GemmShape<kBlockSizeI, kBlockSizeJ, GemmType::ThreadK>;
433     using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
434 
435     using ElementC = output_t;
436     using ElementAccum = accum_t;
437 
438     // no-op output op - epilogue just stores result to global memory
439     using BiasGradEpilogueOutputOp =
440         typename cutlass::epilogue::thread::LinearCombination<
441             ElementC,
442             DefaultConfig::EpilogueOutputOp::kCount,
443             typename DefaultConfig::EpilogueOutputOp::ElementAccumulator,
444             typename DefaultConfig::EpilogueOutputOp::ElementCompute,
445             cutlass::epilogue::thread::ScaleType::Nothing>;
446 
447     using DefaultGemm = typename cutlass::gemm::kernel::DefaultGemm<
448         scalar_t, // ElementA
449         cutlass::layout::RowMajor, // LayoutA
450         kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment,
451         scalar_t, // ElementB
452         cutlass::layout::ColumnMajor, // LayoutB
453         kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment,
454         ElementC, // ElementC
455         cutlass::layout::RowMajor, // LayoutC
456         ElementAccum, // ElementAccumulator
457         typename GemmType::OpClass,
458         ArchTag,
459         ThreadblockShape,
460         WarpShape,
461         typename GemmType::InstructionShape,
462         BiasGradEpilogueOutputOp, // EpilogueOutputOp
463         void, // ThreadblockSwizzle (not used)
464         // multiple preloads, dropout Zij tile, and 3 stages push us over shared
465         // memory capacity on A100. set a ceiling on number of stages to save
466         // shared memory if dropout is in use.
467         kPreload && kApplyDropout && (kBlockSizeI * kBlockSizeJ > 64 * 64)
468             ? cutlass::const_min(2, DefaultConfig::kStages)
469             : DefaultConfig::kStages, // Stages
470         false, // SplitKSerial
471         typename GemmType::Operator,
472         cutlass::gemm::SharedMemoryClearOption::kNone>;
473     using Mma = typename MakeCustomMma<typename DefaultGemm::Mma, kMaxK>::Mma;
474     using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator<
475         typename Mma::Operator::IteratorC,
476         ElementAccum,
477         kWarpSize>::Iterator;
478 
479     // epilogue used to write bias gradient, which is just the output of this
480     // matmul with some operations applied to the fragment
481     using BiasGradEpilogue = typename DefaultGemm::Epilogue;
482 
483     // Epilogue to store to shared-memory in a format that we can use later for
484     // the second matmul
485     using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm<
486         typename DefaultGemm::Mma::Operator::IteratorC,
487         typename DefaultGemm::Mma::Operator,
488         scalar_t,
489         WarpShape,
490         ThreadblockShape>;
491     using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage;
492   };
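  // The accumulator produced here ("tmp", i.e. dSij = Pij * (dPij - Di)) feeds
  // the remaining GEMMs: grad_q accumulates tmp @ k_j (MatmulGradQ) and grad_k
  // accumulates tmp.T @ q_i (MatmulGradK); after the per-fragment operations
  // mentioned above, it is also what BiasGradEpilogue writes out as the bias
  // gradient.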
493 
494   struct MatmulGradQ {
495     // grad_q <- tmp @ k_j
496     using ThreadblockShape =
497         cutlass::gemm::GemmShape<kBlockSizeI, kBlockSizeJ, GemmType::ThreadK>;
498     using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
499     using InstructionShape = typename GemmType::InstructionShape;
500 
501     using DefaultGemm = cutlass::gemm::kernel::DefaultGemm<
502         scalar_t, // ElementA,
503         cutlass::layout::RowMajor, // LayoutA,
504         DefaultConfig::kAlignmentA,
505         scalar_t, // ElementB,
506         cutlass::layout::RowMajor, // LayoutB,
507         kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment,
508         output_t,
509         cutlass::layout::RowMajor, // LayoutC,
510         accum_t,
511         typename GemmType::OpClass,
512         ArchTag,
513         ThreadblockShape,
514         WarpShape,
515         typename GemmType::InstructionShape,
516         typename DefaultConfig::EpilogueOutputOp,
517         void, // ThreadblockSwizzle - not used
518         DefaultConfig::kStages,
519         false, // SplitKSerial
520         typename GemmType::Operator>;
521 
522     using WarpIteratorA = typename cutlass::gemm::threadblock::
523         DefaultWarpIteratorAFromSharedMemory<
524             typename DefaultGemm::Mma::Operator::Shape,
525             typename DefaultGemm::Mma::Operator::InstructionShape,
526             typename DefaultGemm::Mma::Operator::IteratorA,
527             typename DefaultGemm::Mma::Policy>::WarpIterator;
528     using DefaultMmaFromSmem =
529         typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
530             typename DefaultGemm::Mma,
531             MatmulDOIVJ::AccumulatorSharedStorage::Shape::kN,
532             WarpIteratorA,
533             false>; // kScaleOperandA
534     using Mma = typename DefaultMmaFromSmem::Mma;
535     using IteratorB = typename Mma::IteratorB;
536     using WarpCount = typename Mma::WarpCount;
537 
538     // Epilogue
539     using DefaultOutputOp = typename DefaultConfig::EpilogueOutputOp;
540     using DefaultEpilogue = typename DefaultGemm::Epilogue;
541     using OutputTileIterator =
542         typename cutlass::epilogue::threadblock::MakePrefetchableIterator<
543             typename DefaultEpilogue::OutputTileIterator>::Iterator;
544     using AccumTileGmem = GmemTile<typename Mma::FragmentC, (int)kNumThreads>;
545   };
546   struct MatmulGradK {
547     // grad_k <- tmp.transpose(-2, -1) @ q_i
548     using ThreadblockShape =
549         cutlass::gemm::GemmShape<kBlockSizeJ, kBlockSizeI, GemmType::ThreadK>;
550     using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
551     using InstructionShape = typename GemmType::InstructionShape;
552 
553     using DefaultGemm = cutlass::gemm::kernel::DefaultGemm<
554         scalar_t, // ElementA,
555         cutlass::layout::RowMajor, // LayoutA,
556         DefaultConfig::kAlignmentA,
557         scalar_t, // ElementB,
558         cutlass::layout::RowMajor, // LayoutB,
559         kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment,
560         output_t,
561         cutlass::layout::RowMajor, // LayoutC,
562         accum_t,
563         typename GemmType::OpClass,
564         ArchTag,
565         ThreadblockShape,
566         WarpShape,
567         typename GemmType::InstructionShape,
568         typename DefaultConfig::EpilogueOutputOp,
569         void, // ThreadblockSwizzle - not used
570         DefaultConfig::kStages,
571         false, // SplitKSerial
572         typename GemmType::Operator>;
573 
574     using WarpIteratorA = typename cutlass::gemm::threadblock::
575         DefaultWarpIteratorAFromSharedMemory<
576             typename DefaultGemm::Mma::Operator::Shape,
577             typename DefaultGemm::Mma::Operator::InstructionShape,
578             typename DefaultGemm::Mma::Operator::IteratorA,
579             typename DefaultGemm::Mma::Policy>::WarpIterator;
580     using DefaultMmaFromSmemN =
581         typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
582             typename DefaultGemm::Mma,
583             MatmulQK::AccumulatorSharedStorage::Shape::kN, // kMaxK
584             WarpIteratorA,
585             false>; // kScaleOperandA
586     using DefaultMmaFromSmemT =
587         typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
588             typename DefaultGemm::Mma,
589             MatmulDOIVJ::AccumulatorSharedStorage::Shape::kM, // kMaxK
590             WarpIteratorA,
591             false, // kScaleOperandA
592             kPreload>; // kTransposeA
593     using DefaultMmaFromSmem = typename cutlass::platform::conditional<
594         DefaultMmaFromSmemT::kIsTransposedA,
595         DefaultMmaFromSmemT,
596         DefaultMmaFromSmemN>::type;
597     using Mma = typename DefaultMmaFromSmem::Mma;
598     using IteratorB = typename Mma::IteratorB;
599     using WarpCount = typename Mma::WarpCount;
600 
601     // Epilogue
602     using DefaultOutputOp = typename DefaultConfig::EpilogueOutputOp;
603     using DefaultEpilogue = typename DefaultGemm::Epilogue;
604     using OutputTileIterator =
605         typename cutlass::epilogue::threadblock::MakePrefetchableIterator<
606             typename DefaultEpilogue::OutputTileIterator>::Iterator;
607     using AccumTileGmem = GmemTile<typename Mma::FragmentC, (int)kNumThreads>;
608   };
609 
610   // NOTE: nvcc 12.4 has correctness errors with this on M60 (sm52)
611   // when there is an attention bias. Let's just disable it for now.
612   static constexpr auto kMinSm = ArchTag::kMinComputeCapability;
613   static constexpr bool kEnableSplitKeys = kMinSm >= 70;
614 
615   static constexpr bool kNeedsAccumGradQ = kEnableSplitKeys ||
616       !cutlass::platform::is_same<output_accum_t, output_t>::value;
617   static constexpr bool kNeedsAccumGradK = !kOutputInRF &&
618       !cutlass::platform::is_same<output_accum_t, output_t>::value;
619   static constexpr bool kNeedsAccumGradV = !kOutputInRF &&
620       !cutlass::platform::is_same<output_accum_t, output_t>::value;
621 
622   struct GradQTempStorage {
623     int32_t lock;
624     int32_t counter;
625     int32_t pad[2]; // pad to 128bits
626     output_accum_t buffer[MatmulGradQ::AccumTileGmem::kElementsStored];
627   };
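  // Presumed usage (an assumption based on the fields above): split-key blocks
  // use `lock` (via AtomicLock) to serialize accumulation of partial dQ tiles
  // into `buffer`, and `counter` lets the last contributor detect that it
  // should produce the final gradient for that tile.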
628 
629   struct Params {
630     // Input tensors
631     const scalar_t* query_ptr = nullptr; // [Mq, nH, K]
632     const scalar_t* key_ptr = nullptr; // [Mk, nH, K]
633     const scalar_t* value_ptr = nullptr; // [Mk, nH, Kv]
634     const scalar_t* bias_ptr = nullptr;
635     const lse_scalar_t* logsumexp_ptr = nullptr; // [nH, Mq]
636     const scalar_t* output_ptr = nullptr; // [Mq, nH, Kv]
637     const scalar_t* grad_output_ptr = nullptr; // [Mq, nH, Kv]
638     accum_t* delta_ptr = nullptr; // [nH, Mq]
639     const int32_t* cu_seqlens_q_ptr = nullptr;
640     const int32_t* cu_seqlens_k_ptr = nullptr;
641 
642     // Output tensors
643     output_t* grad_query_ptr = nullptr; //  [Mq, nH, K]
644     output_t* grad_key_ptr = nullptr; //    [Mk, nH, K]
645     output_t* grad_value_ptr = nullptr; //  [Mk, nH, Kv]
646     output_t* grad_bias_ptr = nullptr;
647 
648     // Accumulators
649     output_accum_t* workspace = nullptr; // [Mq, Kq] + [Mkv, Kq] + [Mkv, Kv]
650     output_accum_t* workspace_gv =
651         nullptr; // (will be calculated by the kernel)
652     GradQTempStorage* workspace_gq =
653         nullptr; // (will be calculated by the kernel)
654 
655     // Sliding window. ignored if == 0
656     int32_t window_size = 0;
657 
658     // Scale
659     accum_t scale = 1.0f;
660 
661     // Dimensions/strides
662     int32_t head_dim = -1;
663     int32_t head_dim_value = -1;
664     int32_t num_queries = -1;
665     int32_t num_keys = -1;
666     int32_t num_heads = -1;
667     uint8_t custom_mask_type = NoCustomMask;
668 
669     int32_t q_strideM = -1;
670     int32_t k_strideM = -1;
671     int32_t v_strideM = -1;
672     int32_t bias_strideM = 0;
673     int32_t gO_strideM = -1;
674     int32_t gB_strideM = -1;
675     int8_t gQKV_strideM_multiplier = 1; // 3 for packed, 1 otherwise
676 
677     at::PhiloxCudaState rng_engine_inputs = {0, 0};
678 
679     // RNG sequence offset based on batch_id and head_id
680     unsigned long long dropout_batch_head_rng_offset = 0;
681     float dropout_prob = 0.0f;
682 
683     CUTLASS_HOST_DEVICE int32_t o_strideM() const {
684       return head_dim_value * num_heads;
685     }
686     CUTLASS_HOST_DEVICE int32_t gQ_strideM() const {
687       return gQKV_strideM_multiplier * num_heads * head_dim;
688     }
689     CUTLASS_HOST_DEVICE int32_t gK_strideM() const {
690       return gQKV_strideM_multiplier * num_heads * head_dim;
691     }
692     CUTLASS_HOST_DEVICE int32_t gV_strideM() const {
693       return gQKV_strideM_multiplier * num_heads * head_dim_value;
694     }
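    // Layout note (illustrative): for a contiguous [Mq, nH, Kv] tensor the row
    // stride is nH * Kv, which is what o_strideM() returns for the output. The
    // gradient strides above additionally multiply by gQKV_strideM_multiplier,
    // which is 3 when dQ/dK/dV are interleaved slices of a single packed buffer
    // (see the comment on that field) and 1 otherwise.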
695 
696     // Everything below is only used in `advance_to_block`
697     // and shouldn't use registers
698     int64_t o_strideH = -1;
699     int32_t q_strideH = -1;
700     int32_t k_strideH = -1;
701     int32_t v_strideH = -1;
702     int64_t bias_strideH = 0;
703     int64_t o_strideB = -1;
704     int64_t q_strideB = -1;
705     int64_t k_strideB = -1;
706     int64_t v_strideB = -1;
707     int64_t bias_strideB = 0;
708     int64_t lse_strideB = -1;
709     int64_t lse_strideH = -1;
710     int64_t delta_strideB = -1;
711     int64_t delta_strideH = -1;
712     int32_t num_batches = -1;
713     int16_t num_splits_key = 1; // We use `gridDim.x` inside kernel
714 
715     int64_t gO_strideB = 0;
716     int64_t gQ_strideB = 0;
717     int64_t gK_strideB = 0;
718     int64_t gV_strideB = 0;
719     int64_t gB_strideB = 0;
720     int64_t gO_strideH = 0;
721     int64_t gQ_strideH = 0;
722     int64_t gK_strideH = 0;
723     int64_t gV_strideH = 0;
724     int64_t gB_strideH = 0;
725 
726     CUTLASS_HOST_DEVICE int16_t num_splits_key_device() const {
727 #ifdef __CUDA_ARCH__
728       return kEnableSplitKeys ? gridDim.x : 1;
729 #else
730       return num_splits_key; // for host-side tests
731 #endif
732     }
733     CUTLASS_HOST_DEVICE int16_t split_key_device() const {
734 #ifdef __CUDA_ARCH__
735       return kEnableSplitKeys ? blockIdx.x : 0;
736 #else
737       return 0; // for host-side tests
738 #endif
739     }
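    // Grid mapping (see getBlocksGrid() and attention_kernel()): blockIdx.x is
    // the key split, blockIdx.y the head, blockIdx.z the batch. Each split
    // starts at key_start = split_key_device() * kBlockSizeJ and strides over
    // the keys by num_splits_key_device() * kBlockSizeJ.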
740 
741     CUTLASS_DEVICE bool advance_to_block() {
742       int64_t batch_id = blockIdx.z;
743       int32_t head_id = blockIdx.y;
744 
745       if (kNeedsAccumGradQ || kNeedsAccumGradK || kNeedsAccumGradV) {
746         assert(workspace_size() == 0 || workspace != nullptr);
747 
748         workspace += (batch_id * num_heads + head_id) * workspace_strideBH();
749         workspace = warp_uniform(workspace);
750         workspace_gv = workspace + workspace_elements_gk();
751         workspace_gq =
752             (GradQTempStorage*)(workspace_gv + workspace_elements_gv());
753         if (kEnableSplitKeys) {
754           workspace_gv += workspace_elements_gv() * split_key_device() /
755               num_splits_key_device();
756           workspace += workspace_elements_gk() * split_key_device() /
757               num_splits_key_device();
758         }
759       } else {
760         workspace = nullptr;
761       }
762 
763       // Advance pointers that depend on the total concatenated
764       // number of queries, as `num_queries` is modified in the block
765       // below
766       dropout_batch_head_rng_offset =
767           batch_id * (num_heads * num_queries * num_keys) +
768           head_id * (num_queries * num_keys);
769       logsumexp_ptr += batch_id * lse_strideB + head_id * lse_strideH;
770 
771       if (cu_seqlens_q_ptr != nullptr) {
772         assert(cu_seqlens_k_ptr != nullptr);
773         cu_seqlens_q_ptr += batch_id;
774         cu_seqlens_k_ptr += batch_id;
775         int32_t q_start = cu_seqlens_q_ptr[0];
776         int32_t k_start = cu_seqlens_k_ptr[0];
777         int64_t q_next_start = cu_seqlens_q_ptr[1];
778         int64_t k_next_start = cu_seqlens_k_ptr[1];
779         assert(q_next_start - q_start <= num_queries);
780         assert(k_next_start - k_start <= num_keys);
781         num_queries = q_next_start - q_start;
782         num_keys = k_next_start - k_start;
783 
784         // Jump manually
785         batch_id = 0;
786 
787         query_ptr += q_start * q_strideM;
788         key_ptr += k_start * k_strideM;
789         value_ptr += k_start * v_strideM;
790         assert(bias_ptr == nullptr);
791         assert(grad_bias_ptr == nullptr);
792         output_ptr += q_start * o_strideM();
793         grad_output_ptr += q_start * gO_strideM;
794         delta_ptr += q_start;
795 
796         grad_query_ptr += q_start * gQ_strideM();
797         grad_key_ptr += k_start * gK_strideM();
798         grad_value_ptr += k_start * gV_strideM();
799       }
800 
801       query_ptr += batch_id * q_strideB + head_id * q_strideH;
802       key_ptr += batch_id * k_strideB + head_id * k_strideH;
803       value_ptr += batch_id * v_strideB + head_id * v_strideH;
804       if (bias_ptr != nullptr) {
805         bias_ptr += batch_id * bias_strideB + head_id * bias_strideH;
806       }
807       output_ptr += batch_id * o_strideB + head_id * o_strideH;
808       grad_output_ptr += batch_id * gO_strideB + head_id * gO_strideH;
809       delta_ptr += batch_id * delta_strideB + head_id * delta_strideH;
810 
811       grad_query_ptr += batch_id * gQ_strideB + head_id * gQ_strideH;
812       grad_key_ptr += batch_id * gK_strideB + head_id * gK_strideH;
813       grad_value_ptr += batch_id * gV_strideB + head_id * gV_strideH;
814       if (grad_bias_ptr != nullptr) {
815         grad_bias_ptr += batch_id * gB_strideB + head_id * gB_strideH;
816       }
817 
818       // Some values are modified above
819       // Signal to the compiler that they are the same in all threads
820       // and can be stored in warp-uniform registers (Sm75+)
821       num_queries = warp_uniform(num_queries);
822       num_keys = warp_uniform(num_keys);
823       custom_mask_type = warp_uniform(custom_mask_type);
824 
825       query_ptr = warp_uniform(query_ptr);
826       key_ptr = warp_uniform(key_ptr);
827       value_ptr = warp_uniform(value_ptr);
828       bias_ptr = warp_uniform(bias_ptr);
829       logsumexp_ptr = warp_uniform(logsumexp_ptr);
830       output_ptr = warp_uniform(output_ptr);
831       grad_output_ptr = warp_uniform(grad_output_ptr);
832       delta_ptr = warp_uniform(delta_ptr);
833 
834       grad_query_ptr = warp_uniform(grad_query_ptr);
835       grad_key_ptr = warp_uniform(grad_key_ptr);
836       grad_value_ptr = warp_uniform(grad_value_ptr);
837       grad_bias_ptr = warp_uniform(grad_bias_ptr);
838 
839 #if 0
840       PRINT_T0("[b:%d h:%d] dp[0]:%f Q:%f K:%f V:%f LSE:%f",
841         int(blockIdx.z), int(blockIdx.y),
842         float(delta_ptr[0]),
843         float(query_ptr[0]), float(key_ptr[0]), float(value_ptr[0]),
844         float(logsumexp_ptr[0])
845       )
846 #endif
847       return true;
848     }
849 
850     __host__ dim3 getBlocksGrid() const {
851       return dim3(num_splits_key, num_heads, num_batches);
852     }
853     __host__ dim3 getThreadsGrid() const {
854       return dim3(kWarpSize * kNumWarpsPerBlock, 1, 1);
855     }
856     CUTLASS_HOST_DEVICE int64_t workspace_elements_gk() const {
857       if (!kNeedsAccumGradK) {
858         return 0;
859       }
860       return num_splits_key * kBlockSizeJ *
861           align_up(head_dim, (int32_t)kBlockSizeI);
862     }
863     CUTLASS_HOST_DEVICE int64_t workspace_elements_gv() const {
864       if (!kNeedsAccumGradV) {
865         return 0;
866       }
867       return num_splits_key * kBlockSizeJ *
868           align_up(head_dim_value, (int32_t)kBlockSizeI);
869     }
870     CUTLASS_HOST_DEVICE int64_t workspace_elements_gq() const {
871       if (!kNeedsAccumGradQ) {
872         return 0;
873       }
874       int num_blocks = ceil_div(num_queries, kBlockSizeI);
875       int num_cols = ceil_div(head_dim, MatmulGradQ::ThreadblockShape::kN);
876       return num_blocks * num_cols * sizeof(GradQTempStorage) /
877           sizeof(output_accum_t);
878     }
879     CUTLASS_HOST_DEVICE int64_t workspace_strideBH() const {
880       // Aligned on 128bits
881       return align_up(
882           workspace_elements_gk() + workspace_elements_gv() +
883               workspace_elements_gq(),
884           int64_t(4));
885     }
886     CUTLASS_HOST_DEVICE int64_t workspace_size() const {
887       // Returns size of buffer we need to run this kernel
888       return num_batches * num_heads * workspace_strideBH() * sizeof(float);
889     }
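    // Per-(batch, head) workspace layout, in floats (illustrative summary, see
    // also advance_to_block()):
    //   [ workspace_elements_gk() | workspace_elements_gv() | workspace_elements_gq() ]
    // rounded up to a multiple of 4 floats (128 bits) by workspace_strideBH();
    // workspace_size() above is therefore
    // num_batches * num_heads * workspace_strideBH() * sizeof(float) bytes.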
890     CUTLASS_HOST_DEVICE bool should_zero_workspace() const {
891       return num_splits_key > 1 || window_size > 0;
892     }
893   };
894 
895   // shared storage for keeping Zij matrix. not needed if we aren't using
896   // dropout, in which case we use an empty array to save shared memory
897   using ZijSharedStorage = typename cutlass::platform::conditional<
898       kApplyDropout,
899       typename MatmulQK::AccumulatorSharedStorage,
900       // dummy shared storage object that takes up no space.
901       typename cutlass::gemm::threadblock::AccumulatorSharedStorage<
902 #ifdef _WIN32
903           // windows builds throw the error:
904           // "type containing an unknown-size array is not allowed"
905           // if we try to make Zij shared storage zero-sized.
906           // To get around this just make it sized 1 on windows.
907           typename cutlass::gemm::GemmShape<1, 1, 0>,
908 #else
909           typename cutlass::gemm::GemmShape<0, 0, 0>,
910 #endif
911           typename MatmulQK::AccumulatorSharedStorage::Element,
912           typename MatmulQK::AccumulatorSharedStorage::Layout,
913           typename cutlass::MatrixShape<0, 0>>>::type;
914 
915   struct SharedStoragePrologue {
916     struct {
917       cutlass::Array<accum_t, kBlockSizeI> di; // (do_i * o_i).sum(-1)
918       typename MatmulQK::Mma::SharedStorageA mm_qk_k;
919     } persistent;
920     union {
921       struct {
922         // part1 - after Q.K / dV / dO.V
923         union {
924           // 1. efficient load of bias tile Bij, which is then applied to Pij
925           typename MatmulQK::BiasLoader::SmemTile bias;
926           // 4. store Pij. it is needed:
927           // - in dVj += (Pij.T * Zij) @ dOi
928           // - in dSij = Pij * (dPij - Di)
929           // 6. dVj += (Pij.T * Zij) @ dOi
930           // 10. write to fragment
931           typename MatmulQK::AccumulatorSharedStorage attn_shared_storage;
932         };
933         // 5. store Zij. it is needed in dVj += (Pij.T * Zij) @ dOi
934         ZijSharedStorage zij;
935 
936         union {
937           // 2. prologue for dVj
938           // 6. workspace for dVj += (Pij.T * Zij) @ dOi
939           typename MatmulGradV::Mma::SharedStorage mm_gradV;
940           // 7. dVj epilogue
941           typename MatmulGradV::DefaultEpilogue::SharedStorage gradV_epilogue;
942         };
943 
944         // 3. prologue for dPij_dropped
945         // 8. used in dPij_dropped = dOi @ Vj.T
946         typename MatmulDOIVJ::Mma::SharedStorage mm_doivj;
947       } part1;
948 
949       struct {
950         // part2 - dQ
951         union {
952           typename MatmulQK::AccumulatorSharedStorage
953               tmpT_shared_storage; // (from part1)
954           typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage;
955         };
956         typename MatmulGradK::Mma::SharedStorage mm_gradK; // (preload)
957         typename MatmulGradQ::Mma::SharedStorage mm_gradQ; // (preload)
958         union {
959           // store dB = dSij to global memory
960           typename MatmulDOIVJ::BiasGradEpilogue::SharedStorage gradB_epilogue;
961           typename MatmulGradQ::DefaultEpilogue::SharedStorage gradQ_epilogue;
962         };
963 
964       } part2;
965 
966       struct {
967         // part3 - after last iteration on dQ's epilogue / dK
968         union {
969           typename MatmulQK::AccumulatorSharedStorage
970               tmpT_shared_storage; // (from part1)
971           typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage;
972         };
973         typename MatmulGradK::Mma::SharedStorage mm_gradK; // (preload)
974         typename MatmulGradQ::DefaultEpilogue::SharedStorage
975             gradQ_epilogue_lastIter;
976 
977         typename MatmulGradK::DefaultEpilogue::SharedStorage gradK_epilogue;
978       } part3;
979 
980       struct {
981         // part4 - after last iteration on dK's epilogue / preload next K.Q_t
982         typename MatmulQK::Mma::SharedStorageB mm_qk_q;
983 
984         // If we reach end of current key, dump RF->gmem with "final" epilogues
985         typename MatmulGradK::DefaultEpilogue::SharedStorage
986             gradK_epilogue_final;
987         typename MatmulGradV::DefaultEpilogue::SharedStorage
988             gradV_epilogue_final;
989       } part4;
990     };
991     static void print_size() {
992       // Field size
993 #define FSZ(f) int((sizeof(((SharedStoragePrologue*)0)->f)))
994 
995       printf("Total smem: %d bytes\n", int(sizeof(SharedStoragePrologue)));
996       printf("  persistent: %db\n", FSZ(persistent));
997       printf("    mm_qk_k: %db\n", FSZ(persistent.mm_qk_k));
998       printf("  part1: %db\n", FSZ(part1));
999       printf("    bias: %db\n", FSZ(part1.bias));
1000       printf("    attn_shared_storage: %db\n", FSZ(part1.attn_shared_storage));
1001       printf("    zij: %db\n", FSZ(part1.zij));
1002       printf("    mm_gradV: %db\n", FSZ(part1.mm_gradV));
1003       printf("    gradV_epilogue: %db\n", FSZ(part1.gradV_epilogue));
1004       printf("    mm_doivj: %db\n", FSZ(part1.mm_doivj));
1005       printf("  part2: %db\n", FSZ(part2));
1006       printf("    tmpT_shared_storage: %db\n", FSZ(part2.tmpT_shared_storage));
1007       printf("    tmp_shared_storage: %db\n", FSZ(part2.tmp_shared_storage));
1008       printf("    mm_gradK: %db\n", FSZ(part2.mm_gradK));
1009       printf("    mm_gradQ: %db\n", FSZ(part2.mm_gradQ));
1010       printf("    gradB_epilogue: %db\n", FSZ(part2.gradB_epilogue));
1011       printf("    gradQ_epilogue: %db\n", FSZ(part2.gradQ_epilogue));
1012       printf("  part3: %db\n", FSZ(part3));
1013       printf("    tmpT_shared_storage: %db\n", FSZ(part3.tmpT_shared_storage));
1014       printf("  part4: %db\n", FSZ(part4));
1015       printf("    mm_qk_q: %db\n", FSZ(part4.mm_qk_q));
1016       printf(
1017           "    gradK_epilogue_final: %db\n", FSZ(part4.gradK_epilogue_final));
1018       printf(
1019           "    gradV_epilogue_final: %db\n", FSZ(part4.gradV_epilogue_final));
1020     }
1021 // ===========================================
1022 #define FIELD(INSIDE_STRUCT, FIELDNAME) \
1023   CUTLASS_DEVICE auto& FIELDNAME() {    \
1024     return INSIDE_STRUCT.FIELDNAME;     \
1025   }
1026 
1027     FIELD(persistent, di)
1028     FIELD(persistent, mm_qk_k)
1029     FIELD(part1, bias)
1030     FIELD(part1, attn_shared_storage)
1031     FIELD(part1, zij)
1032     FIELD(part1, mm_gradV)
1033     FIELD(part1, gradV_epilogue)
1034     FIELD(part1, mm_doivj)
1035     FIELD(part2, mm_gradK)
1036     FIELD(part2, mm_gradQ)
1037     FIELD(part2, gradB_epilogue)
1038     FIELD(part2, gradQ_epilogue)
1039     FIELD(part2, tmp_shared_storage)
1040     FIELD(part3, tmpT_shared_storage)
1041     FIELD(part3, gradQ_epilogue_lastIter)
1042     FIELD(part3, gradK_epilogue)
1043     FIELD(part4, mm_qk_q)
1044     FIELD(part4, gradK_epilogue_final)
1045     FIELD(part4, gradV_epilogue_final)
1046   };
1047 
1048   struct SharedStorageNoPrologue {
1049     struct {
1050       cutlass::Array<accum_t, kBlockSizeI> di; // (do_i * o_i).sum(-1)
1051     } persistent;
1052     union {
1053       struct {
1054         // part1 - Q.K matmul
1055         typename MatmulQK::Mma::SharedStorageA mm_qk_k;
1056         typename MatmulQK::Mma::SharedStorageB mm_qk_q;
1057       } part1;
1058 
1059       struct {
1060         // part2 - compute gradV
1061         union {
1062           // 1. efficient load of bias tile Bij, which is then applied to Pij
1063           typename MatmulQK::BiasLoader::SmemTile bias;
1064           // 2. store Pij to shared memory. it is needed:
1065           // - in this step, where it is used in dVj += (Pij.T * Zij) @ dOi
1066           // - in next step where it is used in dSij = Pij * (dPij - Di)
1067           typename MatmulQK::AccumulatorSharedStorage attn_shared_storage;
1068         };
1069         // 3. store Zij. it is needed in this step, where it is used
1070         // to compute Pij_dropped = Pij * Zij on the fly as fragments of Pij are
1071         // loaded for the computation of dVj.
1072         ZijSharedStorage zij;
1073 
1074         union {
1075           typename MatmulGradV::Mma::SharedStorage mm_gradV;
1076           typename MatmulGradV::DefaultEpilogue::SharedStorage gradV_epilogue;
1077         };
1078       } part2;
1079 
1080       struct {
1081         // part3 - DO.V matmul
1082         union {
1083           // first compute dPij = (dOi @ Vj.T) * Zij
1084           // and dSij = Pij * (dPij - Di)
1085           struct {
1086             // (from part2) - Pij for computing dSij = Pij * (dPij - Di)
1087             typename MatmulQK::AccumulatorSharedStorage attn_shared_storage;
1088             // matmul to compute dOiVj
1089             typename MatmulDOIVJ::Mma::SharedStorage mm_doivj;
1090           };
1091           // then store dB = dSij to global memory
1092           typename MatmulDOIVJ::BiasGradEpilogue::SharedStorage gradB_epilogue;
1093         };
1094       } part3;
1095 
1096       struct {
1097         // part4 - compute gradQ
1098         typename MatmulQK::AccumulatorSharedStorage
1099             tmpT_shared_storage; // (from part2)
1100         typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage;
1101         union {
1102           typename MatmulGradQ::Mma::SharedStorage mm_gradQ;
1103           typename MatmulGradQ::DefaultEpilogue::SharedStorage gradQ_epilogue;
1104           typename MatmulGradQ::DefaultEpilogue::SharedStorage
1105               gradQ_epilogue_lastIter;
1106         };
1107       } part4;
1108 
1109       struct {
1110         // part5 - compute gradK
1111         typename MatmulQK::AccumulatorSharedStorage
1112             tmpT_shared_storage; // (from part2)
1113         typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage;
1114         union {
1115           typename MatmulGradK::Mma::SharedStorage mm_gradK;
1116           typename MatmulGradK::DefaultEpilogue::SharedStorage gradK_epilogue;
1117         };
1118       } part5;
1119 
1120       struct {
1121         // part6 - store RF accumulated into gmem
1122         typename MatmulGradK::DefaultEpilogue::SharedStorage
1123             gradK_epilogue_final;
1124         typename MatmulGradV::DefaultEpilogue::SharedStorage
1125             gradV_epilogue_final;
1126       } part6;
1127     };
1128     static void print_size() {
1129 #define FIELD_SIZEOF(f) int((sizeof(((SharedStorageNoPrologue*)0)->f)))
1130       printf("Total smem: %d bytes\n", int(sizeof(SharedStorageNoPrologue)));
1131       printf("  persistent: %db\n", FIELD_SIZEOF(persistent));
1132       printf("  part1: %db\n", FIELD_SIZEOF(part1));
1133       printf("  part2: %db\n", FIELD_SIZEOF(part2));
1134       printf("  part3: %db\n", FIELD_SIZEOF(part3));
1135       printf("  part4: %db\n", FIELD_SIZEOF(part4));
1136       printf("  part5: %db\n", FIELD_SIZEOF(part5));
1137       printf("  part6: %db\n", FIELD_SIZEOF(part6));
1138     }
1139 // ===========================================
1140 #define FIELD(INSIDE_STRUCT, FIELDNAME) \
1141   CUTLASS_DEVICE auto& FIELDNAME() {    \
1142     return INSIDE_STRUCT.FIELDNAME;     \
1143   }
1144 
1145     FIELD(persistent, di)
1146     FIELD(part1, mm_qk_k)
1147     FIELD(part1, mm_qk_q)
1148     FIELD(part2, bias)
1149     FIELD(part2, attn_shared_storage)
1150     FIELD(part2, zij)
1151     FIELD(part2, mm_gradV)
1152     FIELD(part2, gradV_epilogue)
1153     FIELD(part3, mm_doivj)
1154     FIELD(part3, gradB_epilogue)
1155     FIELD(part4, tmpT_shared_storage)
1156     FIELD(part4, tmp_shared_storage)
1157     FIELD(part4, mm_gradQ)
1158     FIELD(part4, gradQ_epilogue)
1159     FIELD(part4, gradQ_epilogue_lastIter)
1160     FIELD(part5, mm_gradK)
1161     FIELD(part5, gradK_epilogue)
1162     FIELD(part6, gradK_epilogue_final)
1163     FIELD(part6, gradV_epilogue_final)
1164   };
1165 
1166   using SharedStorage = typename cutlass::platform::conditional<
1167       kPreload,
1168       SharedStoragePrologue,
1169       SharedStorageNoPrologue>::type;
1170 
1171   struct OutputFragments {
1172     typename MatmulGradV::Mma::FragmentC gradV;
1173     typename MatmulGradK::Mma::FragmentC gradK;
1174 
1175     CUTLASS_DEVICE void clear() {
1176       gradV.clear();
1177       gradK.clear();
1178     }
1179   };
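  // When kOutputInRF is true, these fragments carry the running dK / dV
  // accumulators across every query block of the current key block, and are
  // only flushed to gmem once at the end via the gradK/gradV "_final"
  // epilogue shared storage.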
1180 
1181   static bool __host__ check_supported(Params const& p) {
1182     CHECK_ALIGNED_PTR(p.query_ptr, kMinimumAlignment);
1183     CHECK_ALIGNED_PTR(p.key_ptr, kMinimumAlignment);
1184     CHECK_ALIGNED_PTR(p.value_ptr, kMinimumAlignment);
1185     CHECK_ALIGNED_PTR(p.output_ptr, kMinimumAlignment);
1186     CHECK_ALIGNED_PTR(p.grad_output_ptr, kMinimumAlignment);
1187     CHECK_ALIGNED_PTR(p.bias_ptr, kMinimumAlignment);
1188     TORCH_CHECK(
1189         p.num_heads <= 1 || p.lse_strideH % 8 == 0,
1190         "LSE is not correctly aligned (strideH)");
1191     TORCH_CHECK(
1192         p.num_batches <= 1 || p.lse_strideB % 8 == 0,
1193         "LSE is not correctly aligned (strideB)");
1194     TORCH_CHECK(
1195         p.num_heads <= 1 || p.q_strideH % kMinimumAlignment == 0,
1196         "query is not correctly aligned (strideH)");
1197     TORCH_CHECK(
1198         p.num_heads <= 1 || p.k_strideH % kMinimumAlignment == 0,
1199         "key is not correctly aligned (strideH)");
1200     TORCH_CHECK(
1201         p.num_heads <= 1 || p.v_strideH % kMinimumAlignment == 0,
1202         "value is not correctly aligned (strideH)");
1203     TORCH_CHECK(
1204         p.num_batches <= 1 || p.q_strideB % kMinimumAlignment == 0,
1205         "query is not correctly aligned (strideB)");
1206     TORCH_CHECK(
1207         p.num_batches <= 1 || p.k_strideB % kMinimumAlignment == 0,
1208         "key is not correctly aligned (strideB)");
1209     TORCH_CHECK(
1210         p.num_batches <= 1 || p.v_strideB % kMinimumAlignment == 0,
1211         "value is not correctly aligned (strideB)");
1212     TORCH_CHECK(
1213         p.q_strideM % kMinimumAlignment == 0,
1214         "query is not correctly aligned (strideM)");
1215     TORCH_CHECK(
1216         p.k_strideM % kMinimumAlignment == 0,
1217         "key is not correctly aligned (strideM)");
1218     TORCH_CHECK(
1219         p.v_strideM % kMinimumAlignment == 0,
1220         "value is not correctly aligned (strideM)");
1221     if (p.bias_ptr) {
1222       TORCH_CHECK(
1223           p.num_batches <= 1 || p.bias_strideB % kMinimumAlignment == 0,
1224           "attn_bias is not correctly aligned (strideB). ",
1225           "attn_bias.stride(0) = ", p.bias_strideB, ", and should be a "
1226           "multiple of ", kMinimumAlignment, ".");
1227       TORCH_CHECK(
1228           p.num_heads <= 1 || p.bias_strideH % kMinimumAlignment == 0,
1229           "attn_bias is not correctly aligned (strideH). "
1230           "attn_bias.stride(1) = ", p.bias_strideH, ", and should be a "
1231           "multiple of ", kMinimumAlignment, ".");
1232       TORCH_CHECK(
1233           p.num_queries <= 1 || p.bias_strideM % kMinimumAlignment == 0,
1234           "attn_bias is not correctly aligned (strideM). "
1235           "attn_bias.stride(2) = ", p.bias_strideM, ", and should be a ",
1236           "multiple of ", kMinimumAlignment, ".");
1237     }
1238     if (p.grad_bias_ptr) {
1239       TORCH_CHECK(
1240           p.num_batches <= 1 || p.gB_strideB % kMinimumAlignment == 0,
1241           "attn_bias.grad is not correctly aligned (strideB)");
1242       TORCH_CHECK(
1243           p.num_heads <= 1 || p.gB_strideH % kMinimumAlignment == 0,
1244           "attn_bias.grad is not correctly aligned (strideH)");
1245       TORCH_CHECK(
1246           p.gB_strideM % kMinimumAlignment == 0,
1247           "attn_bias.grad is not correctly aligned (strideM)");
1248     }
1249     TORCH_CHECK(
1250         !(p.cu_seqlens_q_ptr && p.bias_ptr),
1251         "CuSeqlen + bias not implemented yet");
1252     TORCH_CHECK(
1253         p.custom_mask_type < NumCustomMaskTypes,
1254         "Invalid value for `custom_mask_type`");
1255     TORCH_CHECK(
1256         p.dropout_prob <= 1.0f && p.dropout_prob >= 0.0f,
1257         "Invalid value for `dropout_prob`");
1258     TORCH_CHECK(
1259         kApplyDropout || p.dropout_prob == 0.0f,
1260         "Set `kApplyDropout`=True to support `dropout_prob > 0`");
1261     TORCH_CHECK(p.head_dim > 0, "Invalid value for `head_dim`");
1262     TORCH_CHECK(p.head_dim_value > 0, "Invalid value for `head_dim_value`");
1263     TORCH_CHECK(p.num_queries > 0, "Invalid value for `num_queries`");
1264     TORCH_CHECK(p.num_keys > 0, "Invalid value for `num_keys`");
1265     TORCH_CHECK(p.num_heads > 0, "Invalid value for `num_heads`");
1266     TORCH_CHECK(p.num_batches > 0, "Invalid value for `num_batches`");
1267     TORCH_CHECK(p.head_dim <= kMaxK, "kMaxK: Expected `head_dim <= kMaxK`");
1268     TORCH_CHECK(
1269         p.head_dim_value <= kMaxK, "kMaxK: Expected `head_dim_value <= kMaxK`");
1270     if (kKeysQueriesAlignedToBlockSize) {
1271       TORCH_CHECK(
1272           p.cu_seqlens_k_ptr == nullptr,
1273           "This kernel does not support cu_seqlen");
1274       TORCH_CHECK(
1275           p.cu_seqlens_q_ptr == nullptr,
1276           "This kernel does not support cu_seqlen");
1277       TORCH_CHECK(
1278           p.num_queries % kBlockSizeI == 0,
1279           "kKeysQueriesAlignedToBlockSize condition not respected");
1280       TORCH_CHECK(
1281           p.num_keys % kBlockSizeJ == 0,
1282           "kKeysQueriesAlignedToBlockSize condition not respected");
1283     }
1284     TORCH_CHECK(
1285         kEnableSplitKeys || p.num_splits_key == 1, "SplitKeys is disabled");
1286     TORCH_CHECK(
1287         p.num_splits_key > 0, "Invalid `num_splits_key` (expected >0)");
1288     TORCH_CHECK(
1289         p.num_splits_key <= cutlass::ceil_div(p.num_keys, kBlockSizeJ),
1290         "Invalid `num_splits_key` (",
1291         p.num_splits_key,
1292         ") - too large for `num_keys` = ",
1293         p.num_keys);
1294     if (p.window_size != 0) {
1295       TORCH_CHECK(
1296           p.custom_mask_type != NoCustomMask,
1297           "LocalAttention only supported in causal mode");
1298     }
1299     return true;
1300   }
1301 
1302   static CUTLASS_DEVICE void attention_kernel(Params p) {
1303     extern __shared__ char smem_buffer[];
1304     SharedStorage& shared_storage = *((SharedStorage*)smem_buffer);
1305 
1306     uint16_t thread_id = threadIdx.x;
1307     uint8_t warp_id = warp_uniform(thread_id / 32);
1308     uint8_t lane_id = thread_id % 32;
1309 
1310     int32_t key_start = p.split_key_device() * kBlockSizeJ;
1311     if (key_start >= p.num_keys) {
1312       return;
1313     }
1314     if (kPrologueQK) {
1315       int32_t query_start = getQueryStart(p, key_start);
1316       prologueQkNextIteration<true>(
1317           shared_storage, p, query_start, key_start, warp_id, lane_id);
1318     }
1319 
1320     // Computes (dO*out).sum(-1) and writes it to `p.delta_ptr`
1321     if (kKernelComputesDelta) {
1322       constexpr int kOptimalElements =
1323           128 / cutlass::sizeof_bits<scalar_t>::value;
1324       if (p.head_dim_value % kOptimalElements == 0) {
1325         for (int query_start = 0; query_start < p.num_queries;
1326              query_start += kBlockSizeI) {
1327           computeDelta<kOptimalElements>(p, query_start, warp_id, lane_id);
1328         }
1329       } else {
1330         for (int query_start = 0; query_start < p.num_queries;
1331              query_start += kBlockSizeI) {
1332           computeDelta<1>(p, query_start, warp_id, lane_id);
1333         }
1334       }
1335       __syncthreads();
1336     }
1337 
1338     OutputFragments output_frags;
1339 
1340     curandStatePhilox4_32_10_t rng_state_init;
1341 
1342     if (kApplyDropout) {
1343       // See Note [Seed and Offset Device]
1344       auto seeds = at::cuda::philox::unpack(p.rng_engine_inputs);
1345       // each element of the attention matrix P with shape
1346       // (batch_sz, n_heads, n_queries, n_keys) is associated with a single
1347       // offset in the RNG sequence. we initialize the RNG state with an offset
1348       // that starts at the beginning of the (n_queries, n_keys) matrix for this
1349       // block's batch_id and head_id.
1350       // initializing the rng state is very expensive, so we do it once per
1351       // kernel, rather than once per iteration. each iteration takes a copy of
1352       // the initialized RNG state and offsets it as needed.
1353       curand_init(
1354           std::get<0>(seeds),
1355           0,
1356           std::get<1>(seeds) + p.dropout_batch_head_rng_offset,
1357           &rng_state_init);
1358     }
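    // Editor's note (a sketch inferred from this call and the skipahead() in
    // processBlockIJ, not an authoritative spec): the Philox offset consumed
    // for element (q, k) of this block's (n_queries, n_keys) matrix is
    //   std::get<1>(seeds) + p.dropout_batch_head_rng_offset + q * p.num_keys + k
    // so consecutive keys of the same query use consecutive offsets.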
1359 
1360     CUTLASS_PRAGMA_UNROLL
1361     for (; key_start < p.num_keys;
1362          key_start += p.num_splits_key_device() * kBlockSizeJ) {
1363       output_frags.clear();
1364 
1365       int32_t next_key = key_start;
1366       int32_t query_start = getQueryStart(p, key_start);
1367       while (next_key == key_start && query_start < p.num_queries) {
1368         // This line here
1369         // vvvvvvvvvvvvvv
1370         warp_id = warp_uniform(warp_id);
1371         // ^^^^^^^^^^^^^^
1372         // ... makes everything use less RF and be 10% faster. Why?
1373         // I don't know. My theory is that it forces `nvcc` to
1374         // re-compute indices, offsets etc... and not keep them
1375         // from the previous iteration, which prevents MASSIVE
1376         // register spilling.
1377 
1378         processBlockIJ<kKeysQueriesAlignedToBlockSize>(
1379             shared_storage,
1380             output_frags,
1381             p,
1382             query_start,
1383             key_start,
1384             rng_state_init,
1385             warp_id,
1386             lane_id);
1387 
1388         int32_t next_query;
1389         incrIteration(p, query_start, key_start, next_query, next_key);
1390         query_start = next_query;
1391       }
1392       if (kOutputInRF) {
1393         writeFragsToGmem<kKeysQueriesAlignedToBlockSize>(
1394             shared_storage, output_frags, p, key_start, warp_id, lane_id);
1395       } else if (getQueryStart(p, key_start) >= p.num_queries) {
1396         zfillGradKV<kKeysQueriesAlignedToBlockSize>(
1397             p, key_start, warp_id, lane_id);
1398       }
1399       __syncthreads();
1400     }
1401   }
1402 
1403   template <bool skipBoundsChecks>
1404   static CUTLASS_DEVICE void zfillGradKV(
1405       Params const& p,
1406       int32_t key_start,
1407       uint8_t warp_id,
1408       uint8_t lane_id) {
1409     constexpr int kThreadsPerKey = 8;
1410     constexpr int kParallelKeys = kNumThreads / kThreadsPerKey;
1411     static_assert(kBlockSizeJ % kParallelKeys == 0, "");
1412     // This function is not really optimized, but it should rarely be used:
1413     // it only runs when some keys are "useless" because no query attends to
1414     // them, due to causal masking
1415 
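    // Editor's note: a small worked example of the thread mapping, assuming
    // kNumThreads == 128 (not fixed by this snippet): kParallelKeys == 16, so
    // within one j-iteration threads 0..7 zero key `key_start + j`, threads
    // 8..15 zero key `key_start + j + 1`, and so on, each thread striding over
    // the head dimension by kThreadsPerKey == 8.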
1416     int thread_id = 32 * warp_id + lane_id;
1417     int k_shift = lane_id % kThreadsPerKey;
1418 
1419     CUTLASS_PRAGMA_UNROLL
1420     for (int j = 0; j < kBlockSizeJ; j += kParallelKeys) {
1421       int key = key_start + j + (thread_id / kThreadsPerKey);
1422       if (!skipBoundsChecks && key >= p.num_keys) {
1423         continue;
1424       }
1425       auto gv_ptr = p.grad_value_ptr + key * p.gV_strideM();
1426       auto gk_ptr = p.grad_key_ptr + key * p.gK_strideM();
1427 
1428       for (int k = k_shift; k < p.head_dim_value; k += kThreadsPerKey) {
1429         gv_ptr[k] = scalar_t(0);
1430       }
1431       for (int k = k_shift; k < p.head_dim; k += kThreadsPerKey) {
1432         gk_ptr[k] = scalar_t(0);
1433       }
1434     }
1435   }
1436 
1437   template <bool skipBoundsChecks>
1438   static CUTLASS_DEVICE void processBlockIJ(
1439       SharedStorage& shared_storage,
1440       OutputFragments& output_frags,
1441       Params& p,
1442       int32_t query_start,
1443       int32_t key_start,
1444       const curandStatePhilox4_32_10_t& curand_state_init,
1445       uint8_t warp_id,
1446       uint8_t lane_id) {
1447     cutlass::Array<cutlass::uint1b_t, MatmulDOIVJ::Mma::FragmentC::kElements>
1448         dropout_keep_mask_doivj;
1449     dropout_keep_mask_doivj.fill(cutlass::uint1b_t{1});
1450     const float dropout_scale =
1451         kApplyDropout ? 1.0 / (1.0 - p.dropout_prob) : 1.0f;
1452 
1453     cutlass::MatrixCoord no_offset{0, 0};
1454     accum_t scale = p.scale;
1455     int16_t thread_id = 32 * warp_id + lane_id;
1456 
1457     auto rematerializeThreadIds = [&]() {
1458       // Prevents `nvcc` from keeping values deduced from
1459       // `thread_id`, `warp_id`, ... in RF - to reduce register pressure
1460       warp_id = warp_uniform(thread_id / 32);
1461       lane_id = thread_id % 32;
1462       thread_id = 32 * warp_id + lane_id;
1463     };
1464 
1465     bool isFirstQuery = (query_start == getQueryStart(p, key_start));
1466     int32_t next_query, next_key;
1467     incrIteration(p, query_start, key_start, next_query, next_key);
1468     bool isLastQuery = next_key != key_start;
1469 
1470     accum_t di_rf = accum_t(0);
1471     if (thread_id < kBlockSizeI) {
1472       if (query_start + thread_id < p.num_queries) {
1473         di_rf = p.delta_ptr[query_start + thread_id];
1474       }
1475       shared_storage.di()[thread_id] = di_rf;
1476     }
1477 
1478     int32_t num_queries_in_block = skipBoundsChecks
1479         ? MatmulQK::Mma::Shape::kN
1480         : warp_uniform(cutlass::fast_min(
1481               (int32_t)MatmulQK::Mma::Shape::kN, p.num_queries - query_start));
1482     int32_t num_keys_in_block = skipBoundsChecks
1483         ? MatmulQK::Mma::Shape::kM
1484         : warp_uniform(cutlass::fast_min(
1485               (int32_t)MatmulQK::Mma::Shape::kM, p.num_keys - key_start));
1486 
1487     auto prologueGradV = [&](int col) {
1488       typename MatmulGradV::Mma::IteratorB iterator_dO(
1489           {int32_t(p.gO_strideM)},
1490           const_cast<scalar_t*>(p.grad_output_ptr + query_start * p.gO_strideM + col),
1491           {num_queries_in_block, p.head_dim_value - col},
1492           thread_id,
1493           no_offset);
1494       MatmulGradV::Mma::prologue(
1495           shared_storage.mm_gradV(),
1496           iterator_dO,
1497           thread_id,
1498           num_queries_in_block);
1499     };
1500     auto prologueGradQ = [&](int col) {
1501       typename MatmulGradQ::Mma::IteratorB iterator_K(
1502           {int32_t(p.k_strideM)},
1503           const_cast<scalar_t*>(p.key_ptr + key_start * p.k_strideM + col),
1504           {num_keys_in_block, p.head_dim - col},
1505           thread_id,
1506           no_offset);
1507       MatmulGradQ::Mma::prologue(
1508           shared_storage.mm_gradQ(), iterator_K, thread_id, num_keys_in_block);
1509     };
1510     auto prologueGradK = [&](int col) {
1511       typename MatmulGradK::Mma::IteratorB iterator_Q(
1512           {int32_t(p.q_strideM)},
1513           const_cast<scalar_t*>(p.query_ptr + query_start * p.q_strideM + col),
1514           {num_queries_in_block, p.head_dim - col},
1515           thread_id,
1516           no_offset);
1517       MatmulGradK::Mma::prologue(
1518           shared_storage.mm_gradK(),
1519           iterator_Q,
1520           thread_id,
1521           num_queries_in_block);
1522     };
1523     auto prologueDOV = [&]() {
1524       typename MatmulDOIVJ::Mma::IteratorA iterator_A(
1525           {int32_t(p.gO_strideM)},
1526           const_cast<scalar_t*>(p.grad_output_ptr + query_start * p.gO_strideM),
1527           {num_queries_in_block, p.head_dim_value},
1528           thread_id,
1529           no_offset);
1530       typename MatmulDOIVJ::Mma::IteratorB iterator_B(
1531           {int32_t(p.v_strideM)},
1532           const_cast<scalar_t*>(p.value_ptr + key_start * p.v_strideM),
1533           {p.head_dim_value, num_keys_in_block},
1534           thread_id,
1535           no_offset);
1536       MatmulDOIVJ::Mma::prologue(
1537           shared_storage.mm_doivj(),
1538           iterator_A,
1539           iterator_B,
1540           thread_id,
1541           p.head_dim_value);
1542     };
1543 
1544     /////////////////////////////////////////////////////////////////////////////////////////////////
1545     // MatmulQK
1546     /////////////////////////////////////////////////////////////////////////////////////////////////
1547     {
1548       using Mma = typename MatmulQK::Mma;
1549 
1550       cutlass::gemm::GemmCoord problem_size(
1551           num_keys_in_block,
1552           num_queries_in_block,
1553           p.head_dim // k
1554       );
1555 
1556       // k_j
1557       typename Mma::IteratorA iterator_A(
1558           {int32_t(p.k_strideM)},
1559           const_cast<scalar_t*>(p.key_ptr + key_start * p.k_strideM),
1560           {problem_size.m(), problem_size.k()},
1561           thread_id,
1562           no_offset);
1563 
1564       // q_i.transpose(-2, -1)
1565       typename Mma::IteratorB iterator_B(
1566           {int32_t(p.q_strideM)},
1567           const_cast<scalar_t*>(p.query_ptr + query_start * p.q_strideM),
1568           {problem_size.k(), problem_size.n()},
1569           thread_id,
1570           no_offset);
1571 
1572       Mma mma(
1573           shared_storage.mm_qk_k(),
1574           shared_storage.mm_qk_q(),
1575           thread_id,
1576           warp_id,
1577           lane_id);
1578 
1579       typename Mma::FragmentC accum;
1580 
1581       accum.clear();
1582 
1583       auto gemm_k_iterations =
1584           (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
1585 
1586       // Compute threadblock-scoped matrix multiply-add
1587       mma.set_prologue_done(kPrologueQK);
1588       mma.set_zero_outside_bounds(!skipBoundsChecks);
1589       mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum);
1590       accum = cutlass::multiplies<typename Mma::FragmentC>()(scale, accum);
1591 
1592       // Epilogue: add LSE + exp and store that to our shared memory buffer
1593       // shmem <- (matmul_result -
1594       // logsumexp[i_start:i_end].unsqueeze(1)).exp()
1595       int warp_idx_mn_0 =
1596           warp_id % (Mma::Base::WarpCount::kM * Mma::Base::WarpCount::kN);
1597       auto output_tile_coords = cutlass::MatrixCoord{
1598           warp_idx_mn_0 % Mma::Base::WarpCount::kM,
1599           warp_idx_mn_0 / Mma::Base::WarpCount::kM};
1600 
1601       // apply bias if applicable
1602       if (p.bias_ptr != nullptr) {
1603         // load bias tile Bij into shared memory
1604         typename MatmulQK::BiasLoader::GmemTileIterator bias_iter(
1605             {cutlass::layout::RowMajor(p.bias_strideM)},
1606             const_cast<scalar_t*>(p.bias_ptr + query_start * p.bias_strideM + key_start),
1607             {num_queries_in_block, num_keys_in_block},
1608             thread_id);
1609         cutlass::TensorRef<scalar_t, cutlass::layout::RowMajor> bias_tensor_ref(
1610             shared_storage.bias().data(),
1611             cutlass::layout::RowMajor(MatmulQK::ThreadblockShape::kM));
1612         typename MatmulQK::BiasLoader::SmemTileIterator smem_tile_iter(
1613             bias_tensor_ref, thread_id);
1614         MatmulQK::BiasLoader::load(bias_iter, smem_tile_iter);
1615 
1616         // Pij += Bij, where Pij is in register fragment and Bij is in shmem
1617         auto lane_offset = MatmulQK::AccumLambdaIterator::get_lane_offset(
1618             lane_id, warp_id, output_tile_coords);
1619         MatmulQK::AccumLambdaIterator::iterateRows(
1620             lane_offset,
1621             [&](int accum_n) {},
1622             [&](int accum_m, int accum_n, int idx) {
1623               // remember we are transposed
1624               accum[idx] += bias_tensor_ref.at({accum_n, accum_m});
1625             },
1626             [&](int accum_n) {});
1627       }
1628 
1629       // Apply mask
1630       if (p.custom_mask_type == CausalFromTopLeft ||
1631           p.custom_mask_type == CausalFromBottomRight) {
1632         auto lane_offset = MatmulQK::AccumLambdaIterator::get_lane_offset(
1633             lane_id, warp_id, output_tile_coords);
1634         int shift = query_start - key_start;
1635         if (p.custom_mask_type == CausalFromBottomRight) {
1636           shift += p.num_keys - p.num_queries;
1637         }
1638         // current_key = key_start + accum_m
1639         // current_query = query_start + accum_n
1640         // mask if: `current_key > current_query`
1641         MatmulQK::AccumLambdaIterator::iterateRows(
1642             lane_offset,
1643             [&](int accum_m) {},
1644             [&](int accum_m, int accum_n, int idx) {
1645               if (accum_m > accum_n + shift) {
1646                 accum[idx] =
1647                     -cutlass::platform::numeric_limits<accum_t>::infinity();
1648               }
1649             },
1650             [&](int accum_m) {});
1651       }
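      // Editor's note: a worked example of the condition above, assuming
      // CausalFromTopLeft with query_start == 0 and key_start == 64, so
      // shift == -64: the entry (accum_m == 0, accum_n == 10) corresponds to
      // key 64 and query 10, and 0 > 10 - 64 holds, so the entry is masked
      // (key 64 comes after query 10).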
1652       if (p.window_size > 0) {
1653         auto lane_offset = MatmulQK::AccumLambdaIterator::get_lane_offset(
1654             lane_id, warp_id, output_tile_coords);
1655         int shift = query_start - key_start - p.window_size;
1656         // current_key = key_start + accum_m
1657         // current_query = query_start + accum_n
1658         // mask if: `current_key <= current_query - window_size`
1659         // i.e. if accum_m <= accum_n + query_start - window_size - key_start
1660 
1661         MatmulQK::AccumLambdaIterator::iterateRows(
1662             lane_offset,
1663             [&](int accum_m) {},
1664             [&](int accum_m, int accum_n, int idx) {
1665               if (accum_m <= accum_n + shift) {
1666                 accum[idx] =
1667                     -cutlass::platform::numeric_limits<accum_t>::infinity();
1668               }
1669             },
1670             [&](int accum_m) {});
1671       }
1672       __syncthreads();
1673       if (kPrologueGV) {
1674         prologueGradV(0);
1675       }
1676       if (kPrologueDOV) {
1677         prologueDOV();
1678       }
1679 
1680       MatmulQK::B2bGemm::accumApplyLSEToSmem(
1681           shared_storage.attn_shared_storage(),
1682           accum,
1683           p.logsumexp_ptr + query_start,
1684           problem_size.n(),
1685           thread_id,
1686           warp_id,
1687           lane_id,
1688           output_tile_coords);
1689 #if 0
1690       auto accum_ref_attnT = shared_storage.attn_shared_storage().accum_ref();
1691       PRINT_TENSOR4x4_T0_L0("attn_T", accum_ref_attnT);
1692 #endif
1693 
1694       // if we are using dropout, compute Zij, writing it to shared memory.
1695       // each element of Zij is:
1696       // - 0 with probability dropout_p
1697       // - 1 / (1 - dropout_p) with probability 1 - dropout_p
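      // Editor's note: with these two values,
      //   E[Zij] = (1 - dropout_p) * 1 / (1 - dropout_p) + dropout_p * 0 = 1,
      // so E[Pij * Zij] = Pij and the dropout-scaled matmuls below are
      // unbiased estimates of their undropped counterparts.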
1698       if (kApplyDropout) {
1699         auto zij = shared_storage.zij().accum_ref();
1700         // each thread generates a contiguous sequence of elements in Zij, all
1701         // in the same row. the reason they have to come from the same row is
1702         // that sampling random numbers from a contiguous random number sequence
1703         // is much more efficient than jumping around, and the linear offset of
1704         // each element of Z (the global matrix) maps to an offset in a random
1705         // number sequence. for Z, the end of a row and the beginning of the
1706         // next have adjacent offsets, but for Zij (tile of global matrix), this
1707         // is not necessarily the case.
1708         // We must fill the entire `zij` shmem with values (even out of bounds
1709         // on the K-dimension) otherwise we can get NaNs during the GEMM
1710         const int kQueriesPerBlock = kBlockSizeI;
1711         const int threads_per_row = cutlass::fast_min(
1712             int32_t(kNumThreads / kQueriesPerBlock), num_keys_in_block);
1713         const int elts_per_thread = cutlass::round_nearest(
1714             cutlass::ceil_div(num_keys_in_block, threads_per_row), 4);
1715 
1716         const int thread_i = thread_id / threads_per_row;
1717         const int thread_start_j =
1718             (thread_id % threads_per_row) * elts_per_thread;
1719 
1720         if (thread_i < kQueriesPerBlock && thread_start_j < num_keys_in_block) {
1721           curandStatePhilox4_32_10_t curand_state = curand_state_init;
1722           skipahead(
1723               (query_start + thread_i) * p.num_keys +
1724                   (key_start + thread_start_j),
1725               &curand_state);
1726 
1727           // generate elements of Zij, 4 elements at a time
1728           for (int zij_start_col_idx = thread_start_j; zij_start_col_idx <
1729                cutlass::fast_min<int32_t>(thread_start_j + elts_per_thread,
1730                                           num_keys_in_block);
1731                zij_start_col_idx += 4) {
1732             const float4 rand_uniform_quad = curand_uniform4(&curand_state);
1733 
1734             CUTLASS_PRAGMA_UNROLL
1735             for (int quad_idx = 0; quad_idx < 4; ++quad_idx) {
1736               // we'll write Zij transposed since attention is also transposed
1737               // during the matmul to compute dV.
1738               zij.at({zij_start_col_idx + quad_idx /*k*/, thread_i /*q*/}) =
1739                   (&rand_uniform_quad.x)[quad_idx] > p.dropout_prob
1740                   ? scalar_t(dropout_scale)
1741                   : scalar_t(0);
1742             }
1743           }
1744         }
1745         __syncthreads();
1746 #if 0
1747         PRINT_TENSOR4x4_T0_L0("zij", zij);
1748         PRINT_TENSOR4x4_T0_L0_START("zij", zij, kBlockSizeJ - 4, kBlockSizeI - 4);
1749 #endif
1750 
1751         // Save mask for later DOIVJ matmul
1752 
1753         int warp_idx_mn_0 = warp_id %
1754             (MatmulDOIVJ::Mma::Base::WarpCount::kM *
1755              MatmulDOIVJ::Mma::Base::WarpCount::kN);
1756         auto output_tile_coords_doivj = cutlass::MatrixCoord{
1757             warp_idx_mn_0 % MatmulDOIVJ::Mma::Base::WarpCount::kM,
1758             warp_idx_mn_0 / MatmulDOIVJ::Mma::Base::WarpCount::kM};
1759         auto lane_offset = MatmulDOIVJ::AccumLambdaIterator::get_lane_offset(
1760             lane_id, warp_id, output_tile_coords_doivj);
1761         MatmulDOIVJ::AccumLambdaIterator::iterateRows(
1762             lane_offset,
1763             [&](int accum_m) {},
1764             [&](int accum_m /*q*/, int accum_n /*k*/, int idx) {
1765               if (zij.at({accum_n, accum_m}) == scalar_t(0)) {
1766                 dropout_keep_mask_doivj[idx] = cutlass::uint1b_t{0};
1767               }
1768             },
1769             [&](int accum_m) {});
1770       }
1771       __syncthreads();
1772     }
1773     rematerializeThreadIds();
1774 
1775     /////////////////////////////////////////////////////////////////////////////////////////////////
1776     // GradV matmul
1777     //
1778     // grad_v[j_start:j_end] += attn_T @ do_i
1779     /////////////////////////////////////////////////////////////////////////////////////////////////
1780     constexpr bool kSingleIterationGradV =
1781         kMaxK <= MatmulGradV::ThreadblockShape::kN;
1782     for (int col = 0; col < (kSingleIterationGradV ? 1 : p.head_dim_value);
1783          col += MatmulGradV::ThreadblockShape::kN) {
1784       using Mma = typename MatmulGradV::Mma;
1785       using AccumTileGmem = typename MatmulGradQ::AccumTileGmem;
1786 
1787       cutlass::gemm::GemmCoord problem_size(
1788           num_keys_in_block, p.head_dim_value - col, num_queries_in_block);
1789       auto createEpilogueIter = [&]() {
1790         return typename MatmulGradV::OutputTileIterator(
1791             typename MatmulGradV::OutputTileIterator::Params{p.gV_strideM()},
1792             p.grad_value_ptr + key_start * p.gV_strideM() + col,
1793             {num_keys_in_block, p.head_dim_value - col},
1794             thread_id);
1795       };
1796       typename Mma::IteratorB iterator_B(
1797           {int32_t(p.gO_strideM)},
1798           const_cast<scalar_t*>(p.grad_output_ptr + query_start * p.gO_strideM + col),
1799           {num_queries_in_block, p.head_dim_value - col},
1800           thread_id,
1801           no_offset);
1802 
1803       // if dropout: dVj += (Pij.T * Zij.T) @ dOi
1804       // otherwise:  dVj += Pij.T @ dOi
1805       Mma mma(
1806           // operand A: Pij.T
1807           shared_storage.attn_shared_storage().accum_ref(),
1808           // operand A_scale Zij.T:
1809           // if we're using dropout, operand A is Pij_dropped.T = Pij.T * Zij.T
1810           // which is computed on the fly as fragments of Pij.T are loaded in
1811           shared_storage.zij().accum_ref(),
1812           // operand B: dOi - which was loaded into shared memory previously
1813           // when we computed dVj
1814           shared_storage.mm_gradV().operand_B_ref(),
1815           thread_id,
1816           warp_id,
1817           lane_id);
1818 
1819       int storage_id = col / MatmulGradV::ThreadblockShape::kN;
1820       AccumTileGmem gmem_tile{
1821           p.workspace_gv + storage_id * AccumTileGmem::kElementsStored};
1822       if (!kOutputInRF) {
1823         if (isFirstQuery || !kNeedsAccumGradV) {
1824           output_frags.gradV.clear();
1825         } else {
1826           gmem_tile.load(output_frags.gradV, thread_id);
1827         }
1828       }
1829       mma.set_prologue_done(kPrologueGV);
1830 
1831       auto gemm_k_iterations =
1832           (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
1833 
1834       // Compute threadblock-scoped matrix multiply-add
1835       __syncthreads();
1836 
1837       mma(gemm_k_iterations,
1838           output_frags.gradV,
1839           iterator_B,
1840           output_frags.gradV);
1841       __syncthreads();
1842       if (kPrologueGV && !kSingleIterationGradV &&
1843           col + MatmulGradV::ThreadblockShape::kN < p.head_dim_value) {
1844         prologueGradV(col + MatmulGradV::ThreadblockShape::kN);
1845       }
1846 
1847       if (!kOutputInRF) {
1848         if (kNeedsAccumGradV && !isLastQuery) {
1849           gmem_tile.store(output_frags.gradV, thread_id);
1850         } else {
1851           accumulateInGmem<MatmulGradV>(
1852               shared_storage.gradV_epilogue(),
1853               output_frags.gradV,
1854               createEpilogueIter(),
1855               isFirstQuery || kNeedsAccumGradV,
1856               warp_id,
1857               lane_id);
1858         }
1859       }
1860     }
1861     __syncthreads();
1862 
1863     /////////////////////////////////////////////////////////////////////////////////////////////////
1864     // MatmulDOIVJ
1865     /////////////////////////////////////////////////////////////////////////////////////////////////
1866     {
1867       using Mma = typename MatmulDOIVJ::Mma;
1868       // do_i
1869       typename Mma::IteratorA iterator_A(
1870           {int32_t(p.gO_strideM)},
1871           const_cast<scalar_t*>(p.grad_output_ptr + query_start * p.gO_strideM),
1872           {num_queries_in_block, p.head_dim_value},
1873           thread_id,
1874           no_offset);
1875 
1876       // v_j.transpose(-2, -1)
1877       typename Mma::IteratorB iterator_B(
1878           {int32_t(p.v_strideM)},
1879           const_cast<scalar_t*>(p.value_ptr + key_start * p.v_strideM),
1880           {p.head_dim_value, num_keys_in_block},
1881           thread_id,
1882           no_offset);
1883 
1884       Mma mma(shared_storage.mm_doivj(), thread_id, warp_id, lane_id);
1885       mma.set_prologue_done(kPrologueDOV);
1886       mma.set_zero_outside_bounds(!skipBoundsChecks);
1887 
1888       typename Mma::FragmentC accum;
1889 
1890       accum.clear();
1891 
1892       auto gemm_k_iterations =
1893           (p.head_dim_value + Mma::Shape::kK - 1) / Mma::Shape::kK;
1894 
1895       // Compute threadblock-scoped matrix multiply-add
1896       mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum);
1897       __syncthreads();
1898       if (kPrologueGQ) {
1899         prologueGradQ(0);
1900       }
1901       if (kPrologueGK) {
1902         prologueGradK(0);
1903       }
1904 
1905       int warp_idx_mn_0 =
1906           warp_id % (Mma::Base::WarpCount::kM * Mma::Base::WarpCount::kN);
1907       auto output_tile_coords = cutlass::MatrixCoord{
1908           warp_idx_mn_0 % Mma::Base::WarpCount::kM,
1909           warp_idx_mn_0 / Mma::Base::WarpCount::kM};
1910       // TODO: This must be terribly inefficient. There must be a better way
1911       // tmp [RF] <- (accum [RF] - Di [smem] ) * attn_T.T [smem]
1912       // attn_shared_storage  [smem] <- tmp.T
1913       // tmp_shared_storage [smem] <- tmp
1914       {
1915         using LambdaIterator = typename MatmulDOIVJ::AccumLambdaIterator;
1916         auto lane_offset = LambdaIterator::get_lane_offset(
1917             lane_id, warp_id, output_tile_coords);
1918         // if dropout was used, compute dPij = dPij_dropped * Zij
1919         if (kApplyDropout) {
1920           LambdaIterator::iterateRows(
1921               lane_offset,
1922               [&](int accum_m) {},
1923               [&](int accum_m, int accum_n, int idx) {
1924                 if (dropout_keep_mask_doivj[idx].get()) {
1925                   accum[idx] *= dropout_scale;
1926                 } else {
1927                   accum[idx] = 0;
1928                 }
1929               },
1930               [&](int accum_m) {});
1931         }
1932 
1933         auto attn_T = shared_storage.attn_shared_storage().accum_ref();
1934 #if 0
1935         PRINT_B0_T0("doivj_dropped");
1936         print_warp_accum<LambdaIterator>(accum, lane_offset, 4, 4);
1937         PRINT_TENSOR4x4_T0_L0("attn_T", attn_T)
1938 #endif
1939         accum_t current_di;
1940         // dSij = (dPij - Di) * Pij
1941         LambdaIterator::iterateRows(
1942             lane_offset,
1943             [&](int accum_m) { current_di = shared_storage.di()[accum_m]; },
1944             [&](int accum_m, int accum_n, int idx) {
1945               // TODO: without this bounds check we can get nans, as we
1946               // might have infs here (only seen on f16 though)
1947               if (skipBoundsChecks ||
1948                   (accum_m < num_queries_in_block &&
1949                    accum_n < num_keys_in_block)) {
1950                 accum_t attn = attn_T.at({accum_n, accum_m});
1951                 accum[idx] = (accum[idx] - current_di) * attn;
1952               } else {
1953                 accum[idx] = 0;
1954               }
1955             },
1956             [&](int accum_m) {
1957 
1958             });
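        // Editor's note: one way to see why this matches the softmax backward:
        //   dSij = Pij * (dPij - sum_k dPik * Pik)
        // and since Oi = sum_k Pik * Vk and dPik = dOi . Vk,
        //   sum_k dPik * Pik = dOi . (sum_k Pik * Vk) = dOi . Oi = Di,
        // which is exactly the per-row value precomputed by computeDelta.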
1959 
1960         // store bias gradient tile dBij to global memory,
1961         // where dBij = dSij = Pij * (dPij - Di)
1962         if (p.grad_bias_ptr != nullptr) {
1963           typename MatmulDOIVJ::BiasGradEpilogue::OutputTileIterator
1964               output_iter(
1965                   typename MatmulDOIVJ::BiasGradEpilogue::OutputTileIterator::
1966                       Params{p.gB_strideM},
1967                   // grad_bias_ptr is offset to point at beginning of
1968                   // matrix of shape (queries, keys) for a given
1969                   // (batch_id, head_id) the pointer arithmetic here produces
1970                   // a pointer to the start of the current tile within that
1971                   // matrix
1972                   p.grad_bias_ptr + query_start * p.gB_strideM + key_start,
1973                   {num_queries_in_block, num_keys_in_block},
1974                   thread_id);
1975 
1976           // no-op epilogue operator - just casting and storing contents of
1977           // accum to global memory
1978           typename MatmulDOIVJ::BiasGradEpilogue::OutputOp output_op({1, 1});
1979           typename MatmulDOIVJ::BiasGradEpilogue epilogue(
1980               shared_storage.gradB_epilogue(), thread_id, warp_id, lane_id);
1981           epilogue(output_op, output_iter, accum, output_iter);
1982         }
1983 
1984         accum = accum * scale;
1985 
1986 #if 0
1987         PRINT_B0_T0("(doivj - di) * attn * scale");
1988         print_warp_accum<LambdaIterator>(accum, lane_offset, 4, 4);
1989 #endif
1990 
1991         __syncthreads();
1992         if (!MatmulGradK::DefaultMmaFromSmem::kIsTransposedA) {
1993           auto tmpT = shared_storage.tmpT_shared_storage().accum_ref();
1994           // attn <- attn_T.T
1995           LambdaIterator::iterateRows(
1996               lane_offset,
1997               [&](int accum_m) {},
1998               [&](int accum_m, int accum_n, int idx) {
1999                 tmpT.at({accum_n, accum_m}) = scalar_t(accum[idx]);
2000               },
2001               [&](int accum_m) {});
2002         }
2003       }
2004 
2005       MatmulDOIVJ::B2bGemm::accumToSmem(
2006           shared_storage.tmp_shared_storage(),
2007           accum,
2008           lane_id,
2009           output_tile_coords);
2010       __syncthreads();
2011     }
2012     // Force `nvcc` to recompute values that depend on the variables just below
2013     // to use less RF and prevent some spilling
2014     p.head_dim = warp_uniform(p.head_dim);
2015     p.k_strideM = warp_uniform(p.k_strideM);
2016     rematerializeThreadIds();
2017 
2018     /////////////////////////////////////////////////////////////////////////////////////////////////
2019     // GradQ matmul
2020     //
2021     // grad_q[i_start:i_end] += tmp @ k_j
2022     /////////////////////////////////////////////////////////////////////////////////////////////////
2023     // Skip the loop & associated branches if we know at compile time the number
2024     // of iterations
2025     constexpr bool kSingleIterationGradQ =
2026         kMaxK <= MatmulGradQ::ThreadblockShape::kN;
2027     for (int col = 0; col < (kSingleIterationGradQ ? 1 : p.head_dim);
2028          col += MatmulGradQ::ThreadblockShape::kN) {
2029       using Mma = typename MatmulGradQ::Mma;
2030       using AccumTileGmem = typename MatmulGradQ::AccumTileGmem;
2031 
2032       cutlass::gemm::GemmCoord problem_size(
2033           num_queries_in_block,
2034           false ? MatmulGradQ::ThreadblockShape::kN : p.head_dim - col,
2035           num_keys_in_block);
2036 
2037       // k_j
2038       typename Mma::IteratorB iterator_B(
2039           {int32_t(p.k_strideM)},
2040           const_cast<scalar_t*>(p.key_ptr + key_start * p.k_strideM + col),
2041           {problem_size.k(), problem_size.n()},
2042           thread_id,
2043           no_offset);
2044 
2045       auto a = shared_storage.tmp_shared_storage().accum_ref();
2046       Mma mma(
2047           // operand A: dSij
2048           shared_storage.tmp_shared_storage().accum_ref(),
2049           // operand B: Kj
2050           shared_storage.mm_gradQ().operand_B_ref(),
2051           thread_id,
2052           warp_id,
2053           lane_id);
2054 
2055       typename Mma::FragmentC accum;
2056 
2057       int col_id = col / MatmulGradQ::ThreadblockShape::kN;
2058       int num_cols = kSingleIterationGradQ
2059           ? 1
2060           : ceil_div(p.head_dim, MatmulGradQ::ThreadblockShape::kN);
2061       int storage_id = (col_id + query_start / kBlockSizeI * num_cols);
2062 
2063       if (p.num_splits_key_device() > 1) {
2064         AtomicLock::acquire(
2065             &p.workspace_gq[storage_id].lock,
2066             p.split_key_device() + 1,
2067             thread_id);
2068         // Make sure we can see other block's output
2069         __threadfence();
2070       }
2071 
2072       AccumTileGmem gmem_tile{&p.workspace_gq[storage_id].buffer[0]};
2073       if (!kNeedsAccumGradQ ||
2074           (p.num_splits_key_device() == 1 && key_start == 0)) {
2075         // if we know we are the first to access it, we know it's only zeros.
2076         // Avoids a load from gmem (and gmem init as well)
2077         accum.clear();
2078       } else {
2079         gmem_tile.load(accum, thread_id);
2080       }
2081 
2082       auto gemm_k_iterations =
2083           (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
2084 
2085       // Compute threadblock-scoped matrix multiply-add
2086       __syncthreads();
2087       mma.set_prologue_done(kPrologueGQ);
2088       mma(gemm_k_iterations, accum, iterator_B, accum);
2089       __syncthreads();
2090       bool isLastColumn = kSingleIterationGradQ ||
2091           (col + MatmulGradQ::ThreadblockShape::kN >= p.head_dim);
2092       if (kPrologueGQ && !isLastColumn) {
2093         prologueGradQ(col + MatmulGradQ::ThreadblockShape::kN);
2094       }
2095 
2096       bool isLast = [&]() {
2097         int32_t next_key = key_start + p.num_splits_key_device() * kBlockSizeJ;
2098         if (p.num_keys <= next_key) {
2099           return true;
2100         }
2101         if (query_start < getSmallestQueryForKey(p, next_key)) {
2102           return true;
2103         }
2104         return false;
2105       }();
2106       // Output results
2107       if (p.num_splits_key_device() > 1) {
2108         int32_t numAddsSoFar = -1;
2109         if (isLast && thread_id == 0) {
2110           numAddsSoFar = atomicAdd(&p.workspace_gq[storage_id].counter, 1) +
2111               1; // `atomicAdd` returns the old value
2112         }
2113         isLast = __syncthreads_or(
2114             numAddsSoFar == getNumParallelBlocksForQuery(p, query_start));
2115         assert(numAddsSoFar <= getNumParallelBlocksForQuery(p, query_start));
2116       }
2117       if (kNeedsAccumGradQ && !isLast) {
2118         gmem_tile.store(accum, thread_id);
2119         if (p.num_splits_key_device() > 1) {
2120           // Make sure everyone wrote before we release the lock
2121           __threadfence();
2122           __syncthreads();
2123           AtomicLock::release(&p.workspace_gq[storage_id].lock, thread_id);
2124         }
2125       } else {
2126         // NOTE: We're not releasing the lock because no one is expected
2127         // to come after us (we're the last one to write)
2128         typename MatmulGradQ::OutputTileIterator output_it(
2129             typename MatmulGradQ::OutputTileIterator::Params{p.gQ_strideM()},
2130             p.grad_query_ptr + query_start * p.gQ_strideM() + col,
2131             {problem_size.m(), problem_size.n()},
2132             thread_id);
2133         // if `direct_store` is True, we store to gmem (`*gmem = accum`)
2134         // otherwise, we accumulate in gmem (`*gmem = *gmem + accum`)
2135         // If we know ahead of time when we will write for the first time
2136         // we can:
2137         // (1) Avoid an additional memory read
2138         // (2) Avoid the cost of initializing memory to 0
2139         bool direct_store = kNeedsAccumGradQ || key_start == 0 ||
2140             (p.num_splits_key_device() > 1);
2141         accumulateInGmem<MatmulGradQ>(
2142             isLastColumn ? shared_storage.gradQ_epilogue_lastIter()
2143                          : shared_storage.gradQ_epilogue(),
2144             accum,
2145             output_it,
2146             direct_store,
2147             warp_id,
2148             lane_id);
2149       }
2150     }
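    // Editor's note: a sketch of the split-key accumulation protocol used
    // above for grad_q, as read from the code (not an authoritative spec):
    //   acquire(workspace_gq[tile].lock); __threadfence();
    //   accum  = first writer ? 0 : load(workspace_gq[tile].buffer);
    //   accum += dSij * scale @ Kj;
    //   if (!last writer) { store(buffer); __threadfence(); release(lock); }
    //   else              { write accum to grad_query in gmem; /* lock kept */ }
    // where "last writer" is detected either statically (no further key block
    // touches this query block) or via the atomicAdd counter above.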
2151     /////////////////////////////////////////////////////////////////////////////////////////////////
2152     // GradK matmul
2153     //
2154     // grad_k[j_start:j_end] += tmp.transpose(-2, -1) @ q_i
2155     /////////////////////////////////////////////////////////////////////////////////////////////////
2156     rematerializeThreadIds();
2157 
2158     constexpr bool kSingleIterationGradK =
2159         kMaxK <= MatmulGradK::ThreadblockShape::kN;
2160     for (int col = 0; col < (kSingleIterationGradK ? 1 : p.head_dim);
2161          col += MatmulGradK::ThreadblockShape::kN) {
2162       using Mma = typename MatmulGradK::Mma;
2163       using AccumTileGmem = typename MatmulGradQ::AccumTileGmem;
2164 
2165       cutlass::gemm::GemmCoord problem_size(
2166           num_keys_in_block,
2167           false ? MatmulGradK::ThreadblockShape::kN : p.head_dim - col,
2168           num_queries_in_block);
2169       auto createEpilogueIter = [&]() {
2170         return typename MatmulGradK::OutputTileIterator(
2171             typename MatmulGradK::OutputTileIterator::Params{p.gK_strideM()},
2172             p.grad_key_ptr + key_start * p.gK_strideM() + col,
2173             {num_keys_in_block,
2174              false ? MatmulGradK::ThreadblockShape::kN : p.head_dim - col},
2175             thread_id);
2176       };
2177 
2178       // q_i
2179       typename Mma::IteratorB iterator_B(
2180           {int32_t(p.q_strideM)},
2181           const_cast<scalar_t*>(p.query_ptr + query_start * p.q_strideM + col),
2182           {problem_size.k(), problem_size.n()},
2183           thread_id,
2184           no_offset);
2185 
2186       auto getTmp = [&](int) { return &shared_storage.tmp_shared_storage(); };
2187       auto getTmpT = [&](int) { return &shared_storage.tmpT_shared_storage(); };
2188       // this is basically:
2189       // opA = kIsTransposedA ? getTmp() : getTmpT();
2190       bool constexpr kIsTransposedA =
2191           MatmulGradK::DefaultMmaFromSmem::kIsTransposedA;
2192       auto& opA = *call_conditional<
2193           kIsTransposedA,
2194           decltype(getTmp),
2195           decltype(getTmpT)>::apply(getTmp, getTmpT, 0);
2196       Mma mma(
2197           // operand A: dSij.T
2198           opA.accum_ref(),
2199           // operand B: Qi
2200           shared_storage.mm_gradK().operand_B_ref(),
2201           thread_id,
2202           warp_id,
2203           lane_id);
2204 
2205       int storage_id = col / MatmulGradK::ThreadblockShape::kN;
2206       AccumTileGmem gmem_tile{
2207           p.workspace + storage_id * AccumTileGmem::kElementsStored};
2208       if (!kOutputInRF) {
2209         if (isFirstQuery || !kNeedsAccumGradK) {
2210           output_frags.gradK.clear();
2211         } else {
2212           gmem_tile.load(output_frags.gradK, thread_id);
2213         }
2214       }
2215       mma.set_prologue_done(kPrologueGK);
2216 
2217       auto gemm_k_iterations =
2218           (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
2219 
2220       // Compute threadblock-scoped matrix multiply-add
2221       __syncthreads();
2222 
2223       mma(gemm_k_iterations,
2224           output_frags.gradK,
2225           iterator_B,
2226           output_frags.gradK);
2227       __syncthreads();
2228       bool isLastColumn = kSingleIterationGradK ||
2229           col + MatmulGradK::ThreadblockShape::kN >= p.head_dim;
2230       if (kPrologueGK && !isLastColumn) {
2231         prologueGradK(col + MatmulGradK::ThreadblockShape::kN);
2232       }
2233 
2234       if (kPrologueQK && isLastColumn) {
2235         int32_t next_query, next_key;
2236         incrIteration(p, query_start, key_start, next_query, next_key);
2237         DISPATCH_BOOL(
2238             next_key != key_start, kForceReloadK, ([&]() {
2239               prologueQkNextIteration<kForceReloadK>(
2240                   shared_storage, p, next_query, next_key, warp_id, lane_id);
2241             }));
2242       }
2243 
2244       // Output results
2245       if (!kOutputInRF) {
2246         if (kNeedsAccumGradK && !isLastQuery) {
2247           gmem_tile.store(output_frags.gradK, thread_id);
2248         } else {
2249           accumulateInGmem<MatmulGradK>(
2250               isLastColumn ? shared_storage.gradK_epilogue_final()
2251                            : shared_storage.gradK_epilogue(),
2252               output_frags.gradK,
2253               createEpilogueIter(),
2254               isFirstQuery || kNeedsAccumGradK,
2255               warp_id,
2256               lane_id);
2257           __syncthreads();
2258         }
2259       }
2260     }
2261   }
2262 
2263   static CUTLASS_HOST_DEVICE int32_t getQueryStartShift(Params const& p) {
2264     if (p.custom_mask_type == NoCustomMask && p.num_splits_key_device() > 1) {
2265       return (p.split_key_device() * kBlockSizeI) % getQueryEnd(p);
2266     }
2267     return 0;
2268   }
2269 
2270   // Iteration order logic
2271   static CUTLASS_HOST_DEVICE int32_t
2272   getQueryStart(Params const& p, int32_t key_start) {
2273     return getSmallestQueryForKey(p, key_start) + getQueryStartShift(p);
2274   };
2275   static CUTLASS_HOST_DEVICE int32_t getQueryEnd(Params const& p) {
2276     return align_up(p.num_queries, kBlockSizeI);
2277   };
2278 
2279   static CUTLASS_HOST_DEVICE int32_t
2280   getSmallestQueryForKey(Params const& p, int32_t key_start) {
2281     if (p.custom_mask_type == NoCustomMask) {
2282       return 0;
2283     }
2284     int32_t shift = p.custom_mask_type == CausalFromBottomRight
2285         ? p.num_keys - p.num_queries
2286         : 0;
2287     int32_t window_size =
2288         p.window_size == 0 ? p.num_queries + p.num_keys : p.window_size;
2289 
2290     auto last_key_for_block =
2291         cutlass::fast_min(key_start + kBlockSizeJ, p.num_keys) - 1;
2292     int first_query = key_start - shift;
2293     int last_query = last_key_for_block - shift + window_size - 1;
2294     if (last_query < 0 || first_query >= p.num_queries) {
2295       return getQueryEnd(p); // nothing to compute in this column
2296     }
2297     first_query = cutlass::fast_max(0, first_query);
2298     return (first_query / kBlockSizeI) * kBlockSizeI;
2299   };
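  // Editor's note: a worked example, assuming kBlockSizeI == kBlockSizeJ == 64
  // (values not fixed by this snippet): with CausalFromTopLeft, window_size == 0,
  // num_queries == num_keys == 256 and key_start == 128, we get shift == 0 and
  // first_query == 128, so the first query block processed for this key block
  // starts at 128, since earlier query blocks cannot attend to these keys.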
2300 
2301   // Returns how many kernel blocks will write to a given block in `grad_query`
2302   // This is usually equal to the number of key splits, but can be smaller,
2303   // for instance in the causal case or with varying sequence lengths
2304   static CUTLASS_HOST_DEVICE int32_t
2305   getNumParallelBlocksForQuery(Params const& p, int32_t query_start) {
2306     int16_t num_key_blocks = ceil_div(p.num_keys, kBlockSizeJ);
2307     if (p.custom_mask_type != NoCustomMask) {
2308       int32_t shift = p.custom_mask_type == CausalFromBottomRight
2309           ? p.num_keys - p.num_queries
2310           : 0;
2311       int32_t last_query_for_block =
2312           cutlass::fast_min(query_start + kBlockSizeI, p.num_queries) - 1;
2313       int32_t last_key_for_block =
2314           cutlass::fast_min(last_query_for_block + shift, p.num_keys - 1);
2315       int32_t first_key_for_block = p.window_size == 0
2316           ? 0
2317           : cutlass::fast_max(query_start - p.window_size + 1 + shift, 0);
2318 
2319       if (p.window_size == 0) {
2320         num_key_blocks = last_key_for_block / kBlockSizeJ + 1;
2321       } else {
2322         num_key_blocks = (last_key_for_block / kBlockSizeJ) -
2323             (first_key_for_block / kBlockSizeJ) + 1;
2324       }
2325 
2326       if (last_key_for_block < 0 || first_key_for_block >= p.num_keys) {
2327         num_key_blocks = 0;
2328       }
2329     }
2330     return cutlass::fast_min(p.num_splits_key_device(), num_key_blocks);
2331   };
2332 
2333   // Returns the next block to process
2334   static CUTLASS_HOST_DEVICE void incrIteration(
2335       Params const& p,
2336       int32_t query_start,
2337       int32_t key_start,
2338       int32_t& next_query,
2339       int32_t& next_key) {
2340     next_query = query_start + kBlockSizeI;
2341     next_key = key_start;
2342     auto query_shift = getQueryStartShift(p);
2343     // Wrap around
2344     if (query_shift) {
2345       if (next_query >= p.num_queries) {
2346         next_query = getSmallestQueryForKey(p, key_start);
2347         return;
2348       } else if (query_start < query_shift && query_shift <= next_query) {
2349         // jump to next key
2350       } else {
2351         return;
2352       }
2353     } else {
2354       if (p.window_size > 0) {
2355         int32_t shift = p.custom_mask_type == CausalFromBottomRight
2356             ? p.num_keys - p.num_queries
2357             : 0;
2358         // last key that is not masked out
2359         int last_key_for_block =
2360             cutlass::fast_min(key_start + kBlockSizeJ, p.num_keys) - 1;
2361         int last_query = last_key_for_block - shift + p.window_size - 1;
2362         if (next_query <= last_query && next_query < p.num_queries) {
2363           return;
2364         }
2365       } else if (next_query < p.num_queries) {
2366         return;
2367       }
2368       // jump to next key
2369     }
2370     // Next key
2371     next_key = key_start + p.num_splits_key_device() * kBlockSizeJ;
2372     next_query = getQueryStart(p, next_key);
2373   }
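  // Editor's note: overall traversal sketch, as implemented by getQueryStart()
  // and incrIteration() above: each kernel block owns the key blocks
  //   key_start, key_start + num_splits_key * kBlockSizeJ, ...
  // and, for each of them, walks the query blocks upwards from
  // getQueryStart(p, key_start) in steps of kBlockSizeI, wrapping around once
  // when a non-zero getQueryStartShift() staggers the start across key splits.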
2374 
2375   template <bool kForceReloadK>
2376   static CUTLASS_DEVICE void prologueQkNextIteration(
2377       SharedStorage& shared_storage,
2378       Params const& p,
2379       int32_t query_start,
2380       int32_t key_start,
2381       uint8_t warp_id,
2382       uint8_t lane_id) {
2383     if (query_start >= p.num_queries || key_start >= p.num_keys) {
2384       return;
2385     }
2386 
2387     static constexpr bool kReloadK =
2388         kForceReloadK || !MatmulQK::Mma::kSmemContainsEntireMat;
2389     int thread_id = 32 * warp_id + lane_id;
2390     typename MatmulQK::Mma::IteratorA iterator_A(
2391         {int32_t(p.k_strideM)},
2392         const_cast<scalar_t*>(p.key_ptr + key_start * p.k_strideM),
2393         {p.num_keys - key_start, p.head_dim},
2394         thread_id,
2395         cutlass::MatrixCoord{0, 0});
2396 
2397     typename MatmulQK::Mma::IteratorB iterator_B(
2398         {int32_t(p.q_strideM)},
2399         const_cast<scalar_t*>(p.query_ptr + query_start * p.q_strideM),
2400         {p.head_dim, p.num_queries - query_start},
2401         thread_id,
2402         cutlass::MatrixCoord{0, 0});
2403 
2404     MatmulQK::Mma::prologue<kReloadK, true>(
2405         shared_storage.mm_qk_k(),
2406         shared_storage.mm_qk_q(),
2407         iterator_A,
2408         iterator_B,
2409         thread_id,
2410         p.head_dim);
2411   }
2412 
2413   template <bool skipBoundsChecks>
2414   static CUTLASS_DEVICE void writeFragsToGmem(
2415       SharedStorage& shared_storage,
2416       OutputFragments& output_frags,
2417       Params const& p,
2418       int32_t key_start,
2419       uint8_t warp_id,
2420       uint8_t lane_id) {
2421     uint16_t thread_id = 32 * warp_id + lane_id;
2422     int32_t num_keys_in_block = skipBoundsChecks
2423         ? MatmulQK::Mma::Shape::kM
2424         : cutlass::fast_min(
2425               (int32_t)MatmulQK::Mma::Shape::kM, p.num_keys - key_start);
2426     typename MatmulGradV::OutputTileIterator outputV_it(
2427         typename MatmulGradV::OutputTileIterator::Params{p.gV_strideM()},
2428         p.grad_value_ptr + key_start * p.gV_strideM(),
2429         {num_keys_in_block, p.head_dim_value},
2430         thread_id);
2431     accumulateInGmem<MatmulGradV>(
2432         shared_storage.gradV_epilogue_final(),
2433         output_frags.gradV,
2434         outputV_it,
2435         true,
2436         warp_id,
2437         lane_id);
2438 
2439     typename MatmulGradK::OutputTileIterator outputK_it(
2440         typename MatmulGradK::OutputTileIterator::Params{p.gK_strideM()},
2441         p.grad_key_ptr + key_start * p.gK_strideM(),
2442         {num_keys_in_block,
2443          false ? MatmulGradK::ThreadblockShape::kN : p.head_dim},
2444         thread_id);
2445     accumulateInGmem<MatmulGradK>(
2446         shared_storage.gradK_epilogue_final(),
2447         output_frags.gradK,
2448         outputK_it,
2449         true,
2450         warp_id,
2451         lane_id);
2452   }
2453 
2454   template <typename MatmulT>
2455   static CUTLASS_DEVICE void accumulateInGmem(
2456       typename MatmulT::DefaultEpilogue::SharedStorage& epilogue_smem,
2457       typename MatmulT::Mma::FragmentC const& accum,
2458       typename MatmulT::OutputTileIterator output_it,
2459       bool first,
2460       uint8_t warp_id,
2461       uint8_t lane_id) {
2462     using DefaultEpilogue = typename MatmulT::DefaultEpilogue;
2463     using DefaultOutputOp = typename MatmulT::DefaultOutputOp;
2464     using Mma = typename MatmulT::Mma;
2465     int thread_id = 32 * warp_id + lane_id;
2466     DISPATCH_BOOL(
2467         first, kIsFirst, ([&]() {
2468           static constexpr auto ScaleType = kIsFirst
2469               ? cutlass::epilogue::thread::ScaleType::Nothing
2470               : cutlass::epilogue::thread::ScaleType::NoBetaScaling;
2471           using EpilogueOutputOp =
2472               typename cutlass::epilogue::thread::LinearCombination<
2473                   typename DefaultOutputOp::ElementOutput,
2474                   DefaultOutputOp::kCount,
2475                   typename DefaultOutputOp::ElementAccumulator,
2476                   typename DefaultOutputOp::ElementCompute,
2477                   ScaleType>;
2478           using Epilogue =
2479               typename cutlass::epilogue::threadblock::EpiloguePipelined<
2480                   typename DefaultEpilogue::Shape,
2481                   typename Mma::Operator,
2482                   DefaultEpilogue::kPartitionsK,
2483                   typename MatmulT::OutputTileIterator,
2484                   typename DefaultEpilogue::AccumulatorFragmentIterator,
2485                   typename DefaultEpilogue::WarpTileIterator,
2486                   typename DefaultEpilogue::SharedLoadIterator,
2487                   EpilogueOutputOp,
2488                   typename DefaultEpilogue::Padding,
2489                   DefaultEpilogue::kFragmentsPerIteration,
2490                   true // IterationsUnroll
2491                   >;
2492           EpilogueOutputOp rescale({1, 1});
2493           Epilogue epilogue(epilogue_smem, thread_id, warp_id, lane_id);
2494           epilogue(rescale, output_it, accum, output_it);
2495         }));
2496   }
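  // Editor's note: with `rescale` set to {1, 1}, the two scale types above
  // should reduce to the following (based on my reading of
  // cutlass::epilogue::thread::LinearCombination; treat this as an assumption):
  //   ScaleType::Nothing       -> gmem  = accum   (first write, no read)
  //   ScaleType::NoBetaScaling -> gmem += accum   (accumulate into gmem)
  // which is exactly what the `first` flag selects between.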
2497 
2498   template <int kElementsPerAccess>
2499   static CUTLASS_DEVICE void computeDelta(
2500       Params const& p,
2501       int32_t query_start,
2502       uint8_t warp_id,
2503       uint8_t lane_id) {
2504     // Each thread computes one value for Delta
2505     // Depending on warp configuration, we might have multiple
2506     // threads of the same warp working on the same row
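    // Editor's note: a small layout example, assuming kNumThreads == 128,
    // kBlockSizeI == 64 and kElementsPerAccess == 8 (these values are not
    // fixed by this snippet): kNumThreadsPerLine == 2, so each row of Delta is
    // handled by 2 threads, each loading 8 contiguous elements per access and
    // advancing by 2 * 8 == 16 columns, with the final cross-thread sum done
    // by the __shfl_xor_sync loop at the end of this function.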
2507     using AccessType = cutlass::Array<scalar_t, kElementsPerAccess>;
2508     static_assert(kNumThreads >= kBlockSizeI, "");
2509     static constexpr int kNumThreadsPerLine = kNumThreads / kBlockSizeI;
2510     int16_t thread_id = 32 * warp_id + lane_id;
2511 
2512     int16_t laneFirstCol = kElementsPerAccess * (lane_id % kNumThreadsPerLine);
2513     int16_t laneRow = thread_id / kNumThreadsPerLine;
2514     bool rowPred = (query_start + laneRow) < p.num_queries;
2515     bool pred = rowPred;
2516 
2517     // On Windows, the previous syntax `__restrict__ AccessType*`
2518     // resulted in the error: "restrict" is not allowed
2519     const AccessType* __restrict__ grad_output_ptr =
2520         reinterpret_cast<const AccessType*>(
2521             p.grad_output_ptr + (query_start + laneRow) * p.gO_strideM +
2522             laneFirstCol);
2523     const AccessType* __restrict__ output_ptr =
2524         reinterpret_cast<const AccessType*>(
2525             p.output_ptr + (query_start + laneRow) * p.o_strideM() +
2526             laneFirstCol);
2527 
2528     static constexpr int64_t kMaxIters =
2529         kMaxK / (kElementsPerAccess * kNumThreadsPerLine);
2530     constexpr int kPipelineStages = 2;
2531     accum_t delta_value = accum_t(0);
2532     using GlobalLoad =
2533         cutlass::arch::global_load<AccessType, sizeof(AccessType)>;
2534     AccessType frag_grad_output[kPipelineStages];
2535     AccessType frag_output[kPipelineStages];
2536 
2537     auto loadAndIncrement = [&](int ld_pos, bool is_valid) {
2538       frag_grad_output[ld_pos].clear();
2539       frag_output[ld_pos].clear();
2540       GlobalLoad(frag_grad_output[ld_pos], grad_output_ptr, is_valid);
2541       GlobalLoad(frag_output[ld_pos], output_ptr, is_valid);
2542       grad_output_ptr += kNumThreadsPerLine;
2543       output_ptr += kNumThreadsPerLine;
2544     };
2545 
2546     CUTLASS_PRAGMA_UNROLL
2547     for (int iter = 0; iter < kPipelineStages - 1; ++iter) {
2548       int ld_pos = iter % kPipelineStages;
2549       pred = pred &&
2550           (laneFirstCol + iter * kElementsPerAccess * kNumThreadsPerLine) <
2551               p.head_dim_value;
2552       loadAndIncrement(ld_pos, pred);
2553     }
2554     auto columnIteration = [&](int iter) {
2555       // Load for next iter
2556       int ld_pos = (iter + kPipelineStages - 1) % kPipelineStages;
2557       pred = pred &&
2558           (laneFirstCol +
2559            (iter + kPipelineStages - 1) * kElementsPerAccess *
2560                kNumThreadsPerLine) < p.head_dim_value;
2561       loadAndIncrement(ld_pos, pred);
2562       CUTLASS_PRAGMA_UNROLL
2563       for (int i = 0; i < AccessType::kElements; ++i) {
2564         delta_value += accum_t(frag_output[iter % kPipelineStages][i]) *
2565             accum_t(frag_grad_output[iter % kPipelineStages][i]);
2566       }
2567     };
2568 
2569     // If we have a small upper bound for K, we can fully unroll the loop
2570     if (kMaxK <= 256) {
2571       CUTLASS_PRAGMA_UNROLL
2572       for (int iter = 0; iter < kMaxIters; ++iter) {
2573         columnIteration(iter);
2574       }
2575     } else {
2576       int num_iters =
2577           ceil_div(p.head_dim_value, kElementsPerAccess * kNumThreadsPerLine) *
2578           (kElementsPerAccess * kNumThreadsPerLine);
2579       for (int iter = 0; iter < num_iters; ++iter) {
2580         columnIteration(iter);
2581       }
2582     }
2583 
2584     // Reduce between workers
2585     static_assert(
2586         kNumThreadsPerLine == 1 || kNumThreadsPerLine == 2 ||
2587             kNumThreadsPerLine == 4,
2588         "");
2589     CUTLASS_PRAGMA_UNROLL
2590     for (int i = 1; i < kNumThreadsPerLine; i *= 2) {
2591       delta_value = delta_value + __shfl_xor_sync(0xffffffff, delta_value, i);
2592     }
2593 
2594     // Store in gmem
2595     if (rowPred) {
2596       p.delta_ptr[query_start + laneRow] = delta_value;
2597     }
2598   }
2599 };
2600 
2601 template <typename AK>
2602 __global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
2603     attention_kernel_backward_batched_impl(typename AK::Params p) {
2604   if (!p.advance_to_block()) {
2605     return;
2606   }
2607   AK::attention_kernel(p);
2608 }
2609 
2610 template <typename AK>
2611 __global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
2612     attention_kernel_backward_batched(typename AK::Params params);
2613 
2614 } // namespace PyTorchMemEffAttention
2615