1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8 #pragma once
9
10 #include <cmath>
11 #include <type_traits>
12 #include <vector>
13
14 #include <cuda_fp16.h>
15 #include <curand_kernel.h>
16
17 #include <ATen/cuda/PhiloxUtils.cuh>
18 #include <cutlass/cutlass.h>
19 #include <cutlass/epilogue/thread/linear_combination.h>
20 #include <cutlass/epilogue/thread/scale_type.h>
21 #include <cutlass/fast_math.h>
22 #include <cutlass/functional.h>
23 #include <cutlass/gemm/gemm.h>
24 #include <cutlass/layout/matrix.h>
25 #include <cutlass/layout/vector.h>
26 #include <cutlass/numeric_conversion.h>
27 #include <cutlass/numeric_types.h>
28 #include <cutlass/tensor_ref.h>
29
30 #include <cutlass/epilogue/thread/linear_combination_relu.h>
31 #include <cutlass/epilogue/threadblock/epilogue_smem_accumulator.h>
32 #include <cutlass/epilogue/warp/fragment_iterator_tensor_op.h>
33 #include <cutlass/epilogue/warp/tile_iterator_tensor_op.h>
34 #include <cutlass/gemm/device/default_gemm_configuration.h>
35 #include <cutlass/gemm/kernel/default_gemm.h>
36 #include <cutlass/gemm/threadblock/default_mma.h>
37 #include <cutlass/gemm/threadblock/default_mma_core_simt.h>
38 #include <cutlass/gemm/threadblock/default_mma_core_sm70.h>
39 #include <cutlass/gemm/threadblock/default_mma_core_sm75.h>
40 #include <cutlass/gemm/threadblock/default_mma_core_sm80.h>
41 #include <cutlass/integer_subbyte.h>
42 #include <cutlass/matrix_shape.h>
43 #include <cutlass/platform/platform.h>
44 #include <cutlass/transform/threadblock/predicated_tile_iterator.h>
45 #include <cutlass/transform/threadblock/vector_iterator.h>
46
47 #include <ATen/native/transformers/cuda/mem_eff_attention/debug_utils.h>
48 #include <ATen/native/transformers/cuda/mem_eff_attention/gemm_kernel_utils.h>
49
50 #include <ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma.h>
51 #include <ATen/native/transformers/cuda/mem_eff_attention/gemm/find_default_mma.h>
52 #include <ATen/native/transformers/cuda/mem_eff_attention/gemm/mma_accum_lambda_iterator.h>
53 #include <ATen/native/transformers/cuda/mem_eff_attention/gemm/mma_from_smem.h>
54
55 #include <ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_pipelined.h>
56 #include <ATen/native/transformers/cuda/mem_eff_attention/iterators/epilogue_predicated_tile_iterator.h>
57 #include <ATen/native/transformers/cuda/mem_eff_attention/transform/tile_smem_loader.h>
58
59 #include <cinttypes>
60 #include <c10/util/Exception.h>
61
62 using namespace gemm_kernel_utils;
63
64 namespace PyTorchMemEffAttention {
65 namespace {
66
67 template <typename FragmentType, int32_t kNumThreads>
68 struct GmemTile {
69 /*
70 Helper functions to efficiently store/load register fragments (RF) to/from gmem
71
72 GEMM accumulators have a particular format on A100, and
73 it takes some compute/shared-memory to rearrange them to
74 a RowMajor or ColumnMajor format in global memory through
75 an Epilogue. The same complexity goes for loading into RF.
76
77 This class loads/stores register fragments as they are, and can be used
78 for efficient accumulation across GEMMs, for instance:
79
80 ```
81 GmemTile tile;
82 for (int i = 0; i < N; ++i) {
83 // ...
84
85 Fragment accum;
86 if (i == 0) {
87 accum.clear();
88 } else {
89 tile.load(accum);
90 }
91 mma(accum, ...);
92 if (i < N-1) {
93 // Store for next GEMM
94 tile.store(accum);
95 } else {
96 // Store in tensor (eg RowMajor)
97 epilogue(accum);
98 }
99
100 // ...
101 }
102 ```
103 */
104
105 // 128bits per thread
106 using AccessType = cutlass::Array<float, 4>;
107 static constexpr int32_t kBytes = sizeof(AccessType);
108 static constexpr int32_t kStride = kNumThreads * AccessType::kElements;
109 static constexpr int32_t kNumIters =
110 FragmentType::kElements / AccessType::kElements;
111 static constexpr int32_t kElementsStored =
112 kNumThreads * FragmentType::kElements;
113 static_assert(
114 FragmentType::kElements % AccessType::kElements == 0,
115 "fragment not aligned on 128 bits");
116
117 float* ptr;
118
119 CUTLASS_DEVICE void load(FragmentType& fragment, int thread_id) {
120 CUTLASS_PRAGMA_UNROLL
121 for (int i = 0; i < kNumIters; ++i) {
122 AccessType* __restrict__ gmem_ptr = reinterpret_cast<AccessType*>(
123 ptr + thread_id * AccessType::kElements + i * kStride);
124 AccessType sub_fragment;
125 cutlass::arch::global_load<AccessType, kBytes>(
126 sub_fragment, gmem_ptr, true);
127 CUTLASS_PRAGMA_UNROLL
128 for (int j = 0; j < AccessType::kElements; ++j) {
129 fragment[i * AccessType::kElements + j] = sub_fragment[j];
130 }
131 }
132 }
133
134 CUTLASS_DEVICE void store(FragmentType const& fragment, int thread_id) {
135 CUTLASS_PRAGMA_UNROLL
136 for (int i = 0; i < kNumIters; ++i) {
137 AccessType* __restrict__ gmem_ptr = reinterpret_cast<AccessType*>(
138 ptr + thread_id * AccessType::kElements + i * kStride);
139 AccessType sub_fragment;
140 CUTLASS_PRAGMA_UNROLL
141 for (int j = 0; j < AccessType::kElements; ++j) {
142 sub_fragment[j] = fragment[i * AccessType::kElements + j];
143 }
144 cutlass::arch::global_store<AccessType, kBytes>(
145 sub_fragment, gmem_ptr, true);
146 }
147 }
148
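// Like `store`, but accumulates into gmem with atomicAdd, so several
// threadblocks can safely add into the same tile (e.g. when more than one
// split-key block contributes to the same accumulator).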
149 CUTLASS_DEVICE void storeAtomicAdd(
150 FragmentType const& fragment,
151 int thread_id) {
152 CUTLASS_PRAGMA_UNROLL
153 for (int i = 0; i < kNumIters; ++i) {
154 float* gmem_ptr = ptr + thread_id * AccessType::kElements + i * kStride;
155 CUTLASS_PRAGMA_UNROLL
156 for (int j = 0; j < AccessType::kElements; ++j) {
157 float val = fragment[i * AccessType::kElements + j];
158 float* ptr = gmem_ptr + j;
159 atomicAdd(ptr, val);
160 }
161 }
162 }
163 };
164
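/*
Minimal spin lock in global memory. `acquire` has thread 0 spin on an
atomicCAS until this block owns the lock (i.e. it holds `set_val`, which
must be non-zero), then `__syncthreads()` publishes ownership to the whole
block; `release` has thread 0 store 0 back with release semantics.
Sketch of the intended usage, guarding a shared gmem accumulation buffer
(e.g. `GradQTempStorage::lock` below) - the call site is illustrative only:

```
AtomicLock::acquire(&storage->lock, blockIdx.x + 1, thread_id);
// ... read-modify-write the buffer guarded by `lock` ...
AtomicLock::release(&storage->lock, thread_id);
```
*/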
165 struct AtomicLock {
166 CUTLASS_DEVICE static void acquire(
167 int32_t* lock,
168 int set_val,
169 int thread_id) {
170 if (thread_id == 0) {
171 while (atomicCAS(lock, 0 /*cmp*/, set_val /*setval*/) != set_val) {
172 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
173 __nanosleep(40);
174 #endif
175 }
176 }
177 __syncthreads();
178 }
179 CUTLASS_DEVICE static void release(int32_t* lock, int thread_id) {
180 if (thread_id == 0) {
181 int status = 0;
182 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
183 asm volatile("st.global.release.gpu.b32 [%0], %1;\n"
184 :
185 : "l"(lock), "r"(status));
186 #else
187 asm volatile("st.global.cg.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
188 #endif
189 }
190 }
191 };
192
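// Target number of warps resident per SM, used below to derive
// `kMinBlocksPerSm` for the kernel's launch bounds.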
193 template <typename scalar_t, typename Arch>
194 constexpr int getWarpsPerSmBw() {
195 bool is_half = !cutlass::platform::is_same<scalar_t, float>::value;
196 if (Arch::kMinComputeCapability >= 80) {
197 return is_half ? 12 : 8;
198 }
199 return 8;
200 }
201 } // namespace
202
203 template <
204 // which arch we target (eg `cutlass::arch::Sm80`)
205 typename ArchTag_,
206 // input/output type
207 typename scalar_t_,
208 // run optimized kernel because memory accesses will be aligned
209 bool kIsAligned_,
210 // use dropout if enabled
211 bool kApplyDropout_,
212 // when doing a GEMM, preload the next one (uses more shmem)
213 bool kPreload_,
214 // block dimensions
215 int kBlockSizeI_,
216 int kBlockSizeJ_,
217 // upper bound on `max(value.shape[-1], query.shape[-1])`
218 int kMaxK_ = (int)cutlass::platform::numeric_limits<uint32_t>::max(),
219 // assumes that `cu_seqlen` is None, and
220 // (1) `num_queries % kBlockSizeI == 0`
221 // (2) `num_keys % kBlockSizeJ == 0`
222 bool kKeysQueriesAlignedToBlockSize_ = false>
223 struct AttentionBackwardKernel {
224 enum CustomMaskType {
225 NoCustomMask = 0,
226 CausalFromTopLeft = 1,
227 CausalFromBottomRight = 2,
228 NumCustomMaskTypes,
229 };
230 using scalar_t = scalar_t_;
231 using output_t = scalar_t;
232 using output_accum_t = float;
233 using lse_scalar_t = float;
234 using accum_t = float;
235 using ArchTag = ArchTag_;
236 static constexpr bool kIsAligned = kIsAligned_;
237 static constexpr bool kApplyDropout = kApplyDropout_;
238 static constexpr bool kPreload = kPreload_;
239 static constexpr int kBlockSizeI = kBlockSizeI_;
240 static constexpr int kBlockSizeJ = kBlockSizeJ_;
241 static constexpr int kMaxK = kMaxK_;
242 static constexpr bool kKeysQueriesAlignedToBlockSize =
243 kKeysQueriesAlignedToBlockSize_;
244
245 static constexpr int64_t kWarpSize = 32;
246
247 // If this is true, we store and accumulate dK/dV in RF
248 // rather than going back to gmem every time
249 static constexpr bool kIsHalf = cutlass::sizeof_bits<scalar_t>::value <= 16;
250 static constexpr bool kOutputInRF = kIsHalf && kMaxK <= kBlockSizeI;
251 static_assert(
252 !kPreload ||
253 (kIsHalf && ArchTag::kMinComputeCapability >= 80 && kOutputInRF),
254 "preload MMA not supported");
255 static constexpr bool kPrologueQK = kPreload;
256 static constexpr bool kPrologueGV = kPreload;
257 static constexpr bool kPrologueDOV = kPreload;
258 static constexpr bool kPrologueGQ = kPreload;
259 static constexpr bool kPrologueGK = kPreload;
260
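// e.g. a 128x64 threadblock tile gives (128 * 64) / (32 * 32) = 8 warps
// (256 threads) per block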
261 static constexpr int64_t kNumWarpsPerBlock =
262 (kBlockSizeI * kBlockSizeJ) / (32 * 32);
263
264 // Compute delta for the f16 kernels
265 // TODO: Figure out why it's slower on the f32 kernels
266 // (something due to RF pressure?)
267 // TODO: Remove condition on `kOutputInRF` - this is needed to work
268 // around a compiler bug on V100, not exactly sure why but I spent
269 // too much time on this already. Reproducible with
270 // (B, Mq, Mkv, K) = (1, 1, 1, 136) for instance
271 static constexpr bool kKernelComputesDelta =
272 kIsHalf && (kOutputInRF || ArchTag::kMinComputeCapability != 70);
273
274 // Launch bounds
275 static constexpr int64_t kNumThreads = kWarpSize * kNumWarpsPerBlock;
276 static constexpr int64_t kMinBlocksPerSm =
277 getWarpsPerSmBw<scalar_t, ArchTag>() / kNumWarpsPerBlock;
278
279 using GemmType = DefaultGemmType<ArchTag, scalar_t>;
280 using DefaultConfig =
281 typename cutlass::gemm::device::DefaultGemmConfiguration<
282 typename GemmType::OpClass,
283 ArchTag,
284 scalar_t,
285 scalar_t,
286 scalar_t, // ElementC
287 accum_t // ElementAccumulator
288 >;
289 static constexpr auto kOptimalAlignement = cutlass::platform::max(
290 DefaultConfig::kAlignmentA,
291 DefaultConfig::kAlignmentB);
292 static constexpr auto kMinimumAlignment = GemmType::kMinimumAlignment;
293
294 struct MatmulQK {
295 /*
296 attn_T = k_j @ q_i.transpose(-2, -1) # matmul
297 attn_T = (attn_T - logsumexp[i_start:i_end].unsqueeze(1).transpose(-2,
298 -1)).exp() # epilogue
299
300 with attn_T.shape = (kBlockSizeJ, kBlockSizeI)
301 */
302 using ThreadblockShape =
303 cutlass::gemm::GemmShape<kBlockSizeJ, kBlockSizeI, GemmType::ThreadK>;
304 using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
305 using DefaultMma = typename cutlass::gemm::threadblock::DefaultMma<
306 scalar_t, // ElementA
307 cutlass::layout::RowMajor, // LayoutA
308 kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment,
309 scalar_t, // ElementB
310 cutlass::layout::ColumnMajor, // LayoutB
311 kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment,
312 accum_t, // ElementC
313 cutlass::layout::RowMajor, // LayoutC
314 typename GemmType::OpClass,
315 ArchTag,
316 ThreadblockShape,
317 WarpShape,
318 typename GemmType::InstructionShape,
319 DefaultConfig::kStages,
320 typename GemmType::Operator,
321 false, // AccumulatorsInRowMajor = false,
322 cutlass::gemm::SharedMemoryClearOption::kNone>;
323 using MmaCore = typename DefaultMma::MmaCore;
324 using Mma =
325 typename MakeCustomMma<typename DefaultMma::ThreadblockMma, kMaxK>::Mma;
326
327 // used for efficient load of bias tile (Bij) from global memory to shared
328 // memory
329 using BiasLoader = TileSmemLoader<
330 scalar_t,
331 // Bij is applied to transposed attn matrix tile (Pij.T). Bij is loaded
332 // row-major but needs to have transposed shape so we get the same
333 // elements.
334 cutlass::MatrixShape<ThreadblockShape::kN, ThreadblockShape::kM>,
335 MmaCore::kThreads,
336 // input restriction: kv_len has to be a multiple of this value
337 128 / cutlass::sizeof_bits<scalar_t>::value>;
338
339 // Epilogue to store to shared-memory in a format that we can use later for
340 // the second matmul
341 using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm<
342 typename Mma::Operator::IteratorC,
343 typename Mma::Operator,
344 scalar_t,
345 WarpShape,
346 ThreadblockShape>;
347 using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator<
348 typename Mma::Operator::IteratorC,
349 accum_t,
350 kWarpSize>::Iterator;
351 using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage;
352 };
353
354 struct MatmulGradV {
355 /*
356 grad_v[j_start:j_end] += attn_T @ do_i # matmul
357
358 Dimensions: (kBlockSizeJ * kNumWarpsPerBlock, kBlockSizeI, K)
359 (we might need to iterate multiple times on K)
360 */
361 using ThreadblockShape =
362 cutlass::gemm::GemmShape<kBlockSizeJ, kBlockSizeI, GemmType::ThreadK>;
363 using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
364 using InstructionShape = typename GemmType::InstructionShape;
365
366 using DefaultGemm = cutlass::gemm::kernel::DefaultGemm<
367 scalar_t, // ElementA,
368 cutlass::layout::RowMajor, // LayoutA,
369 DefaultConfig::kAlignmentA,
370 scalar_t, // ElementB,
371 cutlass::layout::RowMajor, // LayoutB,
372 kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment,
373 output_t,
374 cutlass::layout::RowMajor, // LayoutC,
375 accum_t,
376 typename GemmType::OpClass,
377 ArchTag,
378 ThreadblockShape,
379 WarpShape,
380 typename GemmType::InstructionShape,
381 typename DefaultConfig::EpilogueOutputOp,
382 void, // ThreadblockSwizzle - not used
383 DefaultConfig::kStages,
384 false, // SplitKSerial
385 typename GemmType::Operator>;
386
387 // if dropout:
388 // for computing dVj += (Pij.T * Zij) @ dOi
389 // Pij_dropped.T = Pij.T * Zij is computed on the fly as fragments of
390 // Pij.T are loaded in. The reason we do it this way is because Pij.T and
391 // Zij are reused in later steps, while Pij_dropped.T is only needed in
392 // this step. computing Pij_dropped.T on the fly allows us to avoid
393 // keeping all 3 of Pij_dropped.T, Pij.T, and Zij in shared memory at the
394 // same time.
395 // if no dropout:
396 // for computing dVj += Pij.T @ dOi
397 using WarpIteratorA = typename cutlass::gemm::threadblock::
398 DefaultWarpIteratorAFromSharedMemory<
399 typename DefaultGemm::Mma::Operator::Shape, // WarpShape
400 typename DefaultGemm::Mma::Operator::
401 InstructionShape, // InstructionShape
402 typename DefaultGemm::Mma::Operator::
403 IteratorA, // RegularWarpIterator
404 typename DefaultGemm::Mma::Policy // Policy
405 >::WarpIterator;
406 using DefaultMmaFromSmem =
407 typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
408 typename DefaultGemm::Mma,
409 MatmulQK::AccumulatorSharedStorage::Shape::kN,
410 WarpIteratorA,
411 kApplyDropout>; // kScaleOperandA
412
413 using Mma = typename DefaultMmaFromSmem::Mma;
414 using IteratorB = typename Mma::IteratorB;
415 using WarpCount = typename Mma::WarpCount;
416
417 // Epilogue
418 using DefaultOutputOp = typename DefaultConfig::EpilogueOutputOp;
419 using DefaultEpilogue = typename DefaultGemm::Epilogue;
420 using OutputTileIterator =
421 typename cutlass::epilogue::threadblock::MakePrefetchableIterator<
422 typename DefaultEpilogue::OutputTileIterator>::Iterator;
423 using AccumTileGmem = GmemTile<typename Mma::FragmentC, (int)kNumThreads>;
424 };
425
426 struct MatmulDOIVJ {
427 /*
428 doi_t_vj = do_i @ v_j.transpose(-2, -1) # matmul
429 tmp = (doi_t_vj - Di.unsqueeze(1)) * attn # inplace / epilogue?
430 */
431 using ThreadblockShape =
432 cutlass::gemm::GemmShape<kBlockSizeI, kBlockSizeJ, GemmType::ThreadK>;
433 using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
434
435 using ElementC = output_t;
436 using ElementAccum = accum_t;
437
438 // no-op output op - epilogue just stores result to global memory
439 using BiasGradEpilogueOutputOp =
440 typename cutlass::epilogue::thread::LinearCombination<
441 ElementC,
442 DefaultConfig::EpilogueOutputOp::kCount,
443 typename DefaultConfig::EpilogueOutputOp::ElementAccumulator,
444 typename DefaultConfig::EpilogueOutputOp::ElementCompute,
445 cutlass::epilogue::thread::ScaleType::Nothing>;
446
447 using DefaultGemm = typename cutlass::gemm::kernel::DefaultGemm<
448 scalar_t, // ElementA
449 cutlass::layout::RowMajor, // LayoutA
450 kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment,
451 scalar_t, // ElementB
452 cutlass::layout::ColumnMajor, // LayoutB
453 kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment,
454 ElementC, // ElementC
455 cutlass::layout::RowMajor, // LayoutC
456 ElementAccum, // ElementAccumulator
457 typename GemmType::OpClass,
458 ArchTag,
459 ThreadblockShape,
460 WarpShape,
461 typename GemmType::InstructionShape,
462 BiasGradEpilogueOutputOp, // EpilogueOutputOp
463 void, // ThreadblockSwizzle (not used)
464 // multiple preloads, dropout Zij tile, and 3 stages push us over shared
465 // memory capacity on A100. set a ceiling on number of stages to save
466 // shared memory if dropout is in use.
467 kPreload && kApplyDropout && (kBlockSizeI * kBlockSizeJ > 64 * 64)
468 ? cutlass::const_min(2, DefaultConfig::kStages)
469 : DefaultConfig::kStages, // Stages
470 false, // SplitKSerial
471 typename GemmType::Operator,
472 cutlass::gemm::SharedMemoryClearOption::kNone>;
473 using Mma = typename MakeCustomMma<typename DefaultGemm::Mma, kMaxK>::Mma;
474 using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator<
475 typename Mma::Operator::IteratorC,
476 ElementAccum,
477 kWarpSize>::Iterator;
478
479 // epilogue used to write bias gradient, which is just the output of this
480 // matmul with some operations applied to the fragment
481 using BiasGradEpilogue = typename DefaultGemm::Epilogue;
482
483 // Epilogue to store to shared-memory in a format that we can use later for
484 // the second matmul
485 using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm<
486 typename DefaultGemm::Mma::Operator::IteratorC,
487 typename DefaultGemm::Mma::Operator,
488 scalar_t,
489 WarpShape,
490 ThreadblockShape>;
491 using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage;
492 };
493
494 struct MatmulGradQ {
495 // grad_q <- tmp @ k_j
496 using ThreadblockShape =
497 cutlass::gemm::GemmShape<kBlockSizeI, kBlockSizeJ, GemmType::ThreadK>;
498 using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
499 using InstructionShape = typename GemmType::InstructionShape;
500
501 using DefaultGemm = cutlass::gemm::kernel::DefaultGemm<
502 scalar_t, // ElementA,
503 cutlass::layout::RowMajor, // LayoutA,
504 DefaultConfig::kAlignmentA,
505 scalar_t, // ElementB,
506 cutlass::layout::RowMajor, // LayoutB,
507 kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment,
508 output_t,
509 cutlass::layout::RowMajor, // LayoutC,
510 accum_t,
511 typename GemmType::OpClass,
512 ArchTag,
513 ThreadblockShape,
514 WarpShape,
515 typename GemmType::InstructionShape,
516 typename DefaultConfig::EpilogueOutputOp,
517 void, // ThreadblockSwizzle - not used
518 DefaultConfig::kStages,
519 false, // SplitKSerial
520 typename GemmType::Operator>;
521
522 using WarpIteratorA = typename cutlass::gemm::threadblock::
523 DefaultWarpIteratorAFromSharedMemory<
524 typename DefaultGemm::Mma::Operator::Shape,
525 typename DefaultGemm::Mma::Operator::InstructionShape,
526 typename DefaultGemm::Mma::Operator::IteratorA,
527 typename DefaultGemm::Mma::Policy>::WarpIterator;
528 using DefaultMmaFromSmem =
529 typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
530 typename DefaultGemm::Mma,
531 MatmulDOIVJ::AccumulatorSharedStorage::Shape::kN,
532 WarpIteratorA,
533 false>; // kScaleOperandA
534 using Mma = typename DefaultMmaFromSmem::Mma;
535 using IteratorB = typename Mma::IteratorB;
536 using WarpCount = typename Mma::WarpCount;
537
538 // Epilogue
539 using DefaultOutputOp = typename DefaultConfig::EpilogueOutputOp;
540 using DefaultEpilogue = typename DefaultGemm::Epilogue;
541 using OutputTileIterator =
542 typename cutlass::epilogue::threadblock::MakePrefetchableIterator<
543 typename DefaultEpilogue::OutputTileIterator>::Iterator;
544 using AccumTileGmem = GmemTile<typename Mma::FragmentC, (int)kNumThreads>;
545 };
546 struct MatmulGradK {
547 // grad_k <- tmp.transpose(-2, -1) @ q_i
548 using ThreadblockShape =
549 cutlass::gemm::GemmShape<kBlockSizeJ, kBlockSizeI, GemmType::ThreadK>;
550 using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
551 using InstructionShape = typename GemmType::InstructionShape;
552
553 using DefaultGemm = cutlass::gemm::kernel::DefaultGemm<
554 scalar_t, // ElementA,
555 cutlass::layout::RowMajor, // LayoutA,
556 DefaultConfig::kAlignmentA,
557 scalar_t, // ElementB,
558 cutlass::layout::RowMajor, // LayoutB,
559 kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment,
560 output_t,
561 cutlass::layout::RowMajor, // LayoutC,
562 accum_t,
563 typename GemmType::OpClass,
564 ArchTag,
565 ThreadblockShape,
566 WarpShape,
567 typename GemmType::InstructionShape,
568 typename DefaultConfig::EpilogueOutputOp,
569 void, // ThreadblockSwizzle - not used
570 DefaultConfig::kStages,
571 false, // SplitKSerial
572 typename GemmType::Operator>;
573
574 using WarpIteratorA = typename cutlass::gemm::threadblock::
575 DefaultWarpIteratorAFromSharedMemory<
576 typename DefaultGemm::Mma::Operator::Shape,
577 typename DefaultGemm::Mma::Operator::InstructionShape,
578 typename DefaultGemm::Mma::Operator::IteratorA,
579 typename DefaultGemm::Mma::Policy>::WarpIterator;
580 using DefaultMmaFromSmemN =
581 typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
582 typename DefaultGemm::Mma,
583 MatmulQK::AccumulatorSharedStorage::Shape::kN, // kMaxK
584 WarpIteratorA,
585 false>; // kScaleOperandA
586 using DefaultMmaFromSmemT =
587 typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
588 typename DefaultGemm::Mma,
589 MatmulDOIVJ::AccumulatorSharedStorage::Shape::kM, // kMaxK
590 WarpIteratorA,
591 false, // kScaleOperandA
592 kPreload>; // kTransposeA
593 using DefaultMmaFromSmem = typename cutlass::platform::conditional<
594 DefaultMmaFromSmemT::kIsTransposedA,
595 DefaultMmaFromSmemT,
596 DefaultMmaFromSmemN>::type;
597 using Mma = typename DefaultMmaFromSmem::Mma;
598 using IteratorB = typename Mma::IteratorB;
599 using WarpCount = typename Mma::WarpCount;
600
601 // Epilogue
602 using DefaultOutputOp = typename DefaultConfig::EpilogueOutputOp;
603 using DefaultEpilogue = typename DefaultGemm::Epilogue;
604 using OutputTileIterator =
605 typename cutlass::epilogue::threadblock::MakePrefetchableIterator<
606 typename DefaultEpilogue::OutputTileIterator>::Iterator;
607 using AccumTileGmem = GmemTile<typename Mma::FragmentC, (int)kNumThreads>;
608 };
609
610 // NOTE: nvcc 12.4 has correctness errors with this on M60 (sm52)
611 // when there is an attention bias. Let's just disable it for now.
612 static constexpr auto kMinSm = ArchTag::kMinComputeCapability;
613 static constexpr bool kEnableSplitKeys = kMinSm >= 70;
614
615 static constexpr bool kNeedsAccumGradQ = kEnableSplitKeys ||
616 !cutlass::platform::is_same<output_accum_t, output_t>::value;
617 static constexpr bool kNeedsAccumGradK = !kOutputInRF &&
618 !cutlass::platform::is_same<output_accum_t, output_t>::value;
619 static constexpr bool kNeedsAccumGradV = !kOutputInRF &&
620 !cutlass::platform::is_same<output_accum_t, output_t>::value;
621
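// Gmem scratch used to accumulate dQ for one (query block, head-dim column
// block) when several threadblocks contribute to the same tile (split-keys):
// a 128-bit header (lock/counter/padding) followed by the float accumulation
// buffer. Its footprint is accounted for in `workspace_elements_gq()` below.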
622 struct GradQTempStorage {
623 int32_t lock;
624 int32_t counter;
625 int32_t pad[2]; // pad to 128bits
626 output_accum_t buffer[MatmulGradQ::AccumTileGmem::kElementsStored];
627 };
628
629 struct Params {
630 // Input tensors
631 const scalar_t* query_ptr = nullptr; // [Mq, nH, K]
632 const scalar_t* key_ptr = nullptr; // [Mk, nH, K]
633 const scalar_t* value_ptr = nullptr; // [Mk, nH, Kv]
634 const scalar_t* bias_ptr = nullptr;
635 const lse_scalar_t* logsumexp_ptr = nullptr; // [nH, Mq]
636 const scalar_t* output_ptr = nullptr; // [Mq, nH, Kv]
637 const scalar_t* grad_output_ptr = nullptr; // [Mq, nH, Kv]
638 accum_t* delta_ptr = nullptr; // [nH, Mq]
639 const int32_t* cu_seqlens_q_ptr = nullptr;
640 const int32_t* cu_seqlens_k_ptr = nullptr;
641
642 // Output tensors
643 output_t* grad_query_ptr = nullptr; // [Mq, nH, K]
644 output_t* grad_key_ptr = nullptr; // [Mk, nH, K]
645 output_t* grad_value_ptr = nullptr; // [Mk, nH, Kv]
646 output_t* grad_bias_ptr = nullptr;
647
648 // Accumulators
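// At runtime the per-(batch, head) slice is carved up as
//   [ gK accumulator | gV accumulator | GradQTempStorage array ]
// (see `advance_to_block()` and the `workspace_elements_*()` helpers).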
649 output_accum_t* workspace = nullptr; // [Mq, Kq] + [Mkv, Kq] + [Mkv, Kv]
650 output_accum_t* workspace_gv =
651 nullptr; // (will be calculated by the kernel)
652 GradQTempStorage* workspace_gq =
653 nullptr; // (will be calculated by the kernel)
654
655 // Sliding window; ignored if == 0
656 int32_t window_size = 0;
657
658 // Scale
659 accum_t scale = 1.0f;
660
661 // Dimensions/strides
662 int32_t head_dim = -1;
663 int32_t head_dim_value = -1;
664 int32_t num_queries = -1;
665 int32_t num_keys = -1;
666 int32_t num_heads = -1;
667 uint8_t custom_mask_type = NoCustomMask;
668
669 int32_t q_strideM = -1;
670 int32_t k_strideM = -1;
671 int32_t v_strideM = -1;
672 int32_t bias_strideM = 0;
673 int32_t gO_strideM = -1;
674 int32_t gB_strideM = -1;
675 int8_t gQKV_strideM_multiplier = 1; // 3 for packed, 1 otherwise
676
677 at::PhiloxCudaState rng_engine_inputs = {0, 0};
678
679 // RNG sequence offset based on batch_id and head_id
680 unsigned long long dropout_batch_head_rng_offset = 0;
681 float dropout_prob = 0.0f;
682
683 CUTLASS_HOST_DEVICE int32_t o_strideM() const {
684 return head_dim_value * num_heads;
685 }
686 CUTLASS_HOST_DEVICE int32_t gQ_strideM() const {
687 return gQKV_strideM_multiplier * num_heads * head_dim;
688 }
689 CUTLASS_HOST_DEVICE int32_t gK_strideM() const {
690 return gQKV_strideM_multiplier * num_heads * head_dim;
691 }
692 CUTLASS_HOST_DEVICE int32_t gV_strideM() const {
693 return gQKV_strideM_multiplier * num_heads * head_dim_value;
694 }
695
696 // Everything below is only used in `advance_to_block`
697 // and shouldn't use registers
698 int64_t o_strideH = -1;
699 int32_t q_strideH = -1;
700 int32_t k_strideH = -1;
701 int32_t v_strideH = -1;
702 int64_t bias_strideH = 0;
703 int64_t o_strideB = -1;
704 int64_t q_strideB = -1;
705 int64_t k_strideB = -1;
706 int64_t v_strideB = -1;
707 int64_t bias_strideB = 0;
708 int64_t lse_strideB = -1;
709 int64_t lse_strideH = -1;
710 int64_t delta_strideB = -1;
711 int64_t delta_strideH = -1;
712 int32_t num_batches = -1;
713 int16_t num_splits_key = 1; // We use `gridDim.x` inside kernel
714
715 int64_t gO_strideB = 0;
716 int64_t gQ_strideB = 0;
717 int64_t gK_strideB = 0;
718 int64_t gV_strideB = 0;
719 int64_t gB_strideB = 0;
720 int64_t gO_strideH = 0;
721 int64_t gQ_strideH = 0;
722 int64_t gK_strideH = 0;
723 int64_t gV_strideH = 0;
724 int64_t gB_strideH = 0;
725
726 CUTLASS_HOST_DEVICE int16_t num_splits_key_device() const {
727 #ifdef __CUDA_ARCH__
728 return kEnableSplitKeys ? gridDim.x : 1;
729 #else
730 return num_splits_key; // for host-side tests
731 #endif
732 }
733 CUTLASS_HOST_DEVICE int16_t split_key_device() const {
734 #ifdef __CUDA_ARCH__
735 return kEnableSplitKeys ? blockIdx.x : 0;
736 #else
737 return 0; // for host-side tests
738 #endif
739 }
740
741 CUTLASS_DEVICE bool advance_to_block() {
742 int64_t batch_id = blockIdx.z;
743 int32_t head_id = blockIdx.y;
744
745 if (kNeedsAccumGradQ || kNeedsAccumGradK || kNeedsAccumGradV) {
746 assert(workspace_size() == 0 || workspace != nullptr);
747
748 workspace += (batch_id * num_heads + head_id) * workspace_strideBH();
749 workspace = warp_uniform(workspace);
750 workspace_gv = workspace + workspace_elements_gk();
751 workspace_gq =
752 (GradQTempStorage*)(workspace_gv + workspace_elements_gv());
753 if (kEnableSplitKeys) {
754 workspace_gv += workspace_elements_gv() * split_key_device() /
755 num_splits_key_device();
756 workspace += workspace_elements_gk() * split_key_device() /
757 num_splits_key_device();
758 }
759 } else {
760 workspace = nullptr;
761 }
762
763 // Advance pointers that depend on the total concatenated
764 // number of queries, as `num_queries` is modified in the block
765 // below
766 dropout_batch_head_rng_offset =
767 batch_id * (num_heads * num_queries * num_keys) +
768 head_id * (num_queries * num_keys);
769 logsumexp_ptr += batch_id * lse_strideB + head_id * lse_strideH;
770
771 if (cu_seqlens_q_ptr != nullptr) {
772 assert(cu_seqlens_k_ptr != nullptr);
773 cu_seqlens_q_ptr += batch_id;
774 cu_seqlens_k_ptr += batch_id;
775 int32_t q_start = cu_seqlens_q_ptr[0];
776 int32_t k_start = cu_seqlens_k_ptr[0];
777 int64_t q_next_start = cu_seqlens_q_ptr[1];
778 int64_t k_next_start = cu_seqlens_k_ptr[1];
779 assert(q_next_start - q_start <= num_queries);
780 assert(k_next_start - k_start <= num_keys);
781 num_queries = q_next_start - q_start;
782 num_keys = k_next_start - k_start;
783
784 // Jump manually
785 batch_id = 0;
786
787 query_ptr += q_start * q_strideM;
788 key_ptr += k_start * k_strideM;
789 value_ptr += k_start * v_strideM;
790 assert(bias_ptr == nullptr);
791 assert(grad_bias_ptr == nullptr);
792 output_ptr += q_start * o_strideM();
793 grad_output_ptr += q_start * gO_strideM;
794 delta_ptr += q_start;
795
796 grad_query_ptr += q_start * gQ_strideM();
797 grad_key_ptr += k_start * gK_strideM();
798 grad_value_ptr += k_start * gV_strideM();
799 }
800
801 query_ptr += batch_id * q_strideB + head_id * q_strideH;
802 key_ptr += batch_id * k_strideB + head_id * k_strideH;
803 value_ptr += batch_id * v_strideB + head_id * v_strideH;
804 if (bias_ptr != nullptr) {
805 bias_ptr += batch_id * bias_strideB + head_id * bias_strideH;
806 }
807 output_ptr += batch_id * o_strideB + head_id * o_strideH;
808 grad_output_ptr += batch_id * gO_strideB + head_id * gO_strideH;
809 delta_ptr += batch_id * delta_strideB + head_id * delta_strideH;
810
811 grad_query_ptr += batch_id * gQ_strideB + head_id * gQ_strideH;
812 grad_key_ptr += batch_id * gK_strideB + head_id * gK_strideH;
813 grad_value_ptr += batch_id * gV_strideB + head_id * gV_strideH;
814 if (grad_bias_ptr != nullptr) {
815 grad_bias_ptr += batch_id * gB_strideB + head_id * gB_strideH;
816 }
817
818 // Some values are modified above
819 // Signal to the compiler that they are the same in all threads
820 // and can be stored in warp-uniform registers (Sm75+)
821 num_queries = warp_uniform(num_queries);
822 num_keys = warp_uniform(num_keys);
823 custom_mask_type = warp_uniform(custom_mask_type);
824
825 query_ptr = warp_uniform(query_ptr);
826 key_ptr = warp_uniform(key_ptr);
827 value_ptr = warp_uniform(value_ptr);
828 bias_ptr = warp_uniform(bias_ptr);
829 logsumexp_ptr = warp_uniform(logsumexp_ptr);
830 output_ptr = warp_uniform(output_ptr);
831 grad_output_ptr = warp_uniform(grad_output_ptr);
832 delta_ptr = warp_uniform(delta_ptr);
833
834 grad_query_ptr = warp_uniform(grad_query_ptr);
835 grad_key_ptr = warp_uniform(grad_key_ptr);
836 grad_value_ptr = warp_uniform(grad_value_ptr);
837 grad_bias_ptr = warp_uniform(grad_bias_ptr);
838
839 #if 0
840 PRINT_T0("[b:%d h:%d] dp[0]:%f Q:%f K:%f V:%f LSE:%f",
841 int(blockIdx.z), int(blockIdx.y),
842 float(delta_ptr[0]),
843 float(query_ptr[0]), float(key_ptr[0]), float(value_ptr[0]),
844 float(logsumexp_ptr[0])
845 )
846 #endif
847 return true;
848 }
849
850 __host__ dim3 getBlocksGrid() const {
851 return dim3(num_splits_key, num_heads, num_batches);
852 }
853 __host__ dim3 getThreadsGrid() const {
854 return dim3(kWarpSize * kNumWarpsPerBlock, 1, 1);
855 }
856 CUTLASS_HOST_DEVICE int64_t workspace_elements_gk() const {
857 if (!kNeedsAccumGradK) {
858 return 0;
859 }
860 return num_splits_key * kBlockSizeJ *
861 align_up(head_dim, (int32_t)kBlockSizeI);
862 }
863 CUTLASS_HOST_DEVICE int64_t workspace_elements_gv() const {
864 if (!kNeedsAccumGradV) {
865 return 0;
866 }
867 return num_splits_key * kBlockSizeJ *
868 align_up(head_dim_value, (int32_t)kBlockSizeI);
869 }
870 CUTLASS_HOST_DEVICE int64_t workspace_elements_gq() const {
871 if (!kNeedsAccumGradQ) {
872 return 0;
873 }
874 int num_blocks = ceil_div(num_queries, kBlockSizeI);
875 int num_cols = ceil_div(head_dim, MatmulGradQ::ThreadblockShape::kN);
876 return num_blocks * num_cols * sizeof(GradQTempStorage) /
877 sizeof(output_accum_t);
878 }
879 CUTLASS_HOST_DEVICE int64_t workspace_strideBH() const {
880 // Aligned on 128bits
881 return align_up(
882 workspace_elements_gk() + workspace_elements_gv() +
883 workspace_elements_gq(),
884 int64_t(4));
885 }
886 CUTLASS_HOST_DEVICE int64_t workspace_size() const {
887 // Returns size of buffer we need to run this kernel
888 return num_batches * num_heads * workspace_strideBH() * sizeof(float);
889 }
890 CUTLASS_HOST_DEVICE bool should_zero_workspace() const {
891 return num_splits_key > 1 || window_size > 0;
892 }
893 };
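
/*
Host-side launch sketch (illustration only; the real PyTorch launcher lives
elsewhere). `kernel_entry` stands for a hypothetical __global__ trampoline
that simply calls `AK::attention_kernel(p)`:

```
using AK = AttentionBackwardKernel<
    cutlass::arch::Sm80, cutlass::half_t,
    true,   // kIsAligned
    false,  // kApplyDropout
    true,   // kPreload
    64, 64, // kBlockSizeI, kBlockSizeJ
    64>;    // kMaxK
AK::Params p;
// ... fill pointers, strides and sizes, then:
AK::check_supported(p);
// Allocate `p.workspace_size()` bytes for the float accumulators and point
// `p.workspace` at them; zero the buffer if `p.should_zero_workspace()`.
int smem_bytes = int(sizeof(AK::SharedStorage));
// (if smem_bytes > 48 KiB, raise the kernel's dynamic smem limit first
// with cudaFuncSetAttribute)
kernel_entry<AK><<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes>>>(p);
```
*/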
894
895 // shared storage for keeping Zij matrix. not needed if we aren't using
896 // dropout, in which case we use an empty array to save shared memory
897 using ZijSharedStorage = typename cutlass::platform::conditional<
898 kApplyDropout,
899 typename MatmulQK::AccumulatorSharedStorage,
900 // dummy shared storage object that takes up no space.
901 typename cutlass::gemm::threadblock::AccumulatorSharedStorage<
902 #ifdef _WIN32
903 // windows builds throw the error:
904 // "type containing an unknown-size array is not allowed"
905 // if we try to make Zij shared storage zero-sized.
906 // To get around this just make it sized 1 on windows.
907 typename cutlass::gemm::GemmShape<1, 1, 0>,
908 #else
909 typename cutlass::gemm::GemmShape<0, 0, 0>,
910 #endif
911 typename MatmulQK::AccumulatorSharedStorage::Element,
912 typename MatmulQK::AccumulatorSharedStorage::Layout,
913 typename cutlass::MatrixShape<0, 0>>>::type;
914
915 struct SharedStoragePrologue {
916 struct {
917 cutlass::Array<accum_t, kBlockSizeI> di; // (do_i * o_i).sum(-1)
918 typename MatmulQK::Mma::SharedStorageA mm_qk_k;
919 } persistent;
920 union {
921 struct {
922 // part1 - after Q.K / dV / dO.V
923 union {
924 // 1. efficient load of bias tile Bij, which is then applied to Pij
925 typename MatmulQK::BiasLoader::SmemTile bias;
926 // 4. store Pij. it is needed:
927 // - in dVj += (Pij.T * Zij) @ dOi
928 // - in dSij = Pij * (dPij - Di)
929 // 6. dVj += (Pij.T * Zij) @ dOi
930 // 10. write to fragment
931 typename MatmulQK::AccumulatorSharedStorage attn_shared_storage;
932 };
933 // 5. store Zij. it is needed in dVj += (Pij.T * Zij) @ dOi
934 ZijSharedStorage zij;
935
936 union {
937 // 2. prologue for dVj
938 // 6. workspace for dVj += (Pij.T * Zij) @ dOi
939 typename MatmulGradV::Mma::SharedStorage mm_gradV;
940 // 7. dVj epilogue
941 typename MatmulGradV::DefaultEpilogue::SharedStorage gradV_epilogue;
942 };
943
944 // 3. prologue for dPij_dropped
945 // 8. used in dPij_dropped = dOi @ Vj.T
946 typename MatmulDOIVJ::Mma::SharedStorage mm_doivj;
947 } part1;
948
949 struct {
950 // part2 - dQ
951 union {
952 typename MatmulQK::AccumulatorSharedStorage
953 tmpT_shared_storage; // (from part1)
954 typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage;
955 };
956 typename MatmulGradK::Mma::SharedStorage mm_gradK; // (preload)
957 typename MatmulGradQ::Mma::SharedStorage mm_gradQ; // (preload)
958 union {
959 // store dB = dSij to global memory
960 typename MatmulDOIVJ::BiasGradEpilogue::SharedStorage gradB_epilogue;
961 typename MatmulGradQ::DefaultEpilogue::SharedStorage gradQ_epilogue;
962 };
963
964 } part2;
965
966 struct {
967 // part3 - after last iteration on dQ's epilogue / dK
968 union {
969 typename MatmulQK::AccumulatorSharedStorage
970 tmpT_shared_storage; // (from part1)
971 typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage;
972 };
973 typename MatmulGradK::Mma::SharedStorage mm_gradK; // (preload)
974 typename MatmulGradQ::DefaultEpilogue::SharedStorage
975 gradQ_epilogue_lastIter;
976
977 typename MatmulGradK::DefaultEpilogue::SharedStorage gradK_epilogue;
978 } part3;
979
980 struct {
981 // part4 - after last iteration on dK's epilogue / preload next K.Q_t
982 typename MatmulQK::Mma::SharedStorageB mm_qk_q;
983
984 // If we reach end of current key, dump RF->gmem with "final" epilogues
985 typename MatmulGradK::DefaultEpilogue::SharedStorage
986 gradK_epilogue_final;
987 typename MatmulGradV::DefaultEpilogue::SharedStorage
988 gradV_epilogue_final;
989 } part4;
990 };
991 static void print_size() {
992 // Field size
993 #define FSZ(f) int((sizeof(((SharedStoragePrologue*)0)->f)))
994
995 printf("Total smem: %d bytes\n", int(sizeof(SharedStoragePrologue)));
996 printf(" persistent: %db\n", FSZ(persistent));
997 printf(" mm_qk_k: %db\n", FSZ(persistent.mm_qk_k));
998 printf(" part1: %db\n", FSZ(part1));
999 printf(" bias: %db\n", FSZ(part1.bias));
1000 printf(" attn_shared_storage: %db\n", FSZ(part1.attn_shared_storage));
1001 printf(" zij: %db\n", FSZ(part1.zij));
1002 printf(" mm_gradV: %db\n", FSZ(part1.mm_gradV));
1003 printf(" gradV_epilogue: %db\n", FSZ(part1.gradV_epilogue));
1004 printf(" mm_doivj: %db\n", FSZ(part1.mm_doivj));
1005 printf(" part2: %db\n", FSZ(part2));
1006 printf(" tmpT_shared_storage: %db\n", FSZ(part2.tmpT_shared_storage));
1007 printf(" tmp_shared_storage: %db\n", FSZ(part2.tmp_shared_storage));
1008 printf(" mm_gradK: %db\n", FSZ(part2.mm_gradK));
1009 printf(" mm_gradQ: %db\n", FSZ(part2.mm_gradQ));
1010 printf(" gradB_epilogue: %db\n", FSZ(part2.gradB_epilogue));
1011 printf(" gradQ_epilogue: %db\n", FSZ(part2.gradQ_epilogue));
1012 printf(" part3: %db\n", FSZ(part3));
1013 printf(" tmpT_shared_storage: %db\n", FSZ(part3.tmpT_shared_storage));
1014 printf(" part4: %db\n", FSZ(part4));
1015 printf(" mm_qk_q: %db\n", FSZ(part4.mm_qk_q));
1016 printf(
1017 " gradK_epilogue_final: %db\n", FSZ(part4.gradK_epilogue_final));
1018 printf(
1019 " gradV_epilogue_final: %db\n", FSZ(part4.gradV_epilogue_final));
1020 }
1021 // ===========================================
1022 #define FIELD(INSIDE_STRUCT, FIELDNAME) \
1023 CUTLASS_DEVICE auto& FIELDNAME() { \
1024 return INSIDE_STRUCT.FIELDNAME; \
1025 }
1026
1027 FIELD(persistent, di)
1028 FIELD(persistent, mm_qk_k)
1029 FIELD(part1, bias)
1030 FIELD(part1, attn_shared_storage)
1031 FIELD(part1, zij)
1032 FIELD(part1, mm_gradV)
1033 FIELD(part1, gradV_epilogue)
1034 FIELD(part1, mm_doivj)
1035 FIELD(part2, mm_gradK)
1036 FIELD(part2, mm_gradQ)
1037 FIELD(part2, gradB_epilogue)
1038 FIELD(part2, gradQ_epilogue)
1039 FIELD(part2, tmp_shared_storage)
1040 FIELD(part3, tmpT_shared_storage)
1041 FIELD(part3, gradQ_epilogue_lastIter)
1042 FIELD(part3, gradK_epilogue)
1043 FIELD(part4, mm_qk_q)
1044 FIELD(part4, gradK_epilogue_final)
1045 FIELD(part4, gradV_epilogue_final)
1046 };
1047
1048 struct SharedStorageNoPrologue {
1049 struct {
1050 cutlass::Array<accum_t, kBlockSizeI> di; // (do_i * o_i).sum(-1)
1051 } persistent;
1052 union {
1053 struct {
1054 // part1 - Q.K matmul
1055 typename MatmulQK::Mma::SharedStorageA mm_qk_k;
1056 typename MatmulQK::Mma::SharedStorageB mm_qk_q;
1057 } part1;
1058
1059 struct {
1060 // part2 - compute gradV
1061 union {
1062 // 1. efficient load of bias tile Bij, which is then applied to Pij
1063 typename MatmulQK::BiasLoader::SmemTile bias;
1064 // 2. store Pij to shared memory. it is needed:
1065 // - in this step, where it is used in dVj += (Pij.T * Zij) @ dOi
1066 // - in next step where it is used in dSij = Pij * (dPij - Di)
1067 typename MatmulQK::AccumulatorSharedStorage attn_shared_storage;
1068 };
1069 // 3. store Zij. it is needed in this step, where it is used
1070 // to compute Pij_dropped = Pij * Zij on the fly as fragments of Pij are
1071 // loaded for the computation of dVj.
1072 ZijSharedStorage zij;
1073
1074 union {
1075 typename MatmulGradV::Mma::SharedStorage mm_gradV;
1076 typename MatmulGradV::DefaultEpilogue::SharedStorage gradV_epilogue;
1077 };
1078 } part2;
1079
1080 struct {
1081 // part3 - DO.V matmul
1082 union {
1083 // first compute dPij = (dOi @ Vj.T) * Zij
1084 // and dSij = Pij * (dPij - Di)
1085 struct {
1086 // (from part2) - Pij for computing dSij = Pij * (dPij - Di)
1087 typename MatmulQK::AccumulatorSharedStorage attn_shared_storage;
1088 // matmul to compute dOiVj
1089 typename MatmulDOIVJ::Mma::SharedStorage mm_doivj;
1090 };
1091 // then store dB = dSij to global memory
1092 typename MatmulDOIVJ::BiasGradEpilogue::SharedStorage gradB_epilogue;
1093 };
1094 } part3;
1095
1096 struct {
1097 // part4 - compute gradQ
1098 typename MatmulQK::AccumulatorSharedStorage
1099 tmpT_shared_storage; // (from part2)
1100 typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage;
1101 union {
1102 typename MatmulGradQ::Mma::SharedStorage mm_gradQ;
1103 typename MatmulGradQ::DefaultEpilogue::SharedStorage gradQ_epilogue;
1104 typename MatmulGradQ::DefaultEpilogue::SharedStorage
1105 gradQ_epilogue_lastIter;
1106 };
1107 } part4;
1108
1109 struct {
1110 // part5 - compute gradK
1111 typename MatmulQK::AccumulatorSharedStorage
1112 tmpT_shared_storage; // (from part2)
1113 typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage;
1114 union {
1115 typename MatmulGradK::Mma::SharedStorage mm_gradK;
1116 typename MatmulGradK::DefaultEpilogue::SharedStorage gradK_epilogue;
1117 };
1118 } part5;
1119
1120 struct {
1121 // part6 - store RF accumulated into gmem
1122 typename MatmulGradK::DefaultEpilogue::SharedStorage
1123 gradK_epilogue_final;
1124 typename MatmulGradV::DefaultEpilogue::SharedStorage
1125 gradV_epilogue_final;
1126 } part6;
1127 };
1128 static void print_size() {
1129 #define FIELD_SIZEOF(f) int((sizeof(((SharedStorageNoPrologue*)0)->f)))
1130 printf("Total smem: %d bytes\n", int(sizeof(SharedStorageNoPrologue)));
1131 printf(" persistent: %db\n", FIELD_SIZEOF(persistent));
1132 printf(" part1: %db\n", FIELD_SIZEOF(part1));
1133 printf(" part2: %db\n", FIELD_SIZEOF(part2));
1134 printf(" part3: %db\n", FIELD_SIZEOF(part3));
1135 printf(" part4: %db\n", FIELD_SIZEOF(part4));
1136 printf(" part5: %db\n", FIELD_SIZEOF(part5));
1137 printf(" part6: %db\n", FIELD_SIZEOF(part6));
1138 }
1139 // ===========================================
1140 #define FIELD(INSIDE_STRUCT, FIELDNAME) \
1141 CUTLASS_DEVICE auto& FIELDNAME() { \
1142 return INSIDE_STRUCT.FIELDNAME; \
1143 }
1144
1145 FIELD(persistent, di)
1146 FIELD(part1, mm_qk_k)
1147 FIELD(part1, mm_qk_q)
1148 FIELD(part2, bias)
1149 FIELD(part2, attn_shared_storage)
1150 FIELD(part2, zij)
1151 FIELD(part2, mm_gradV)
1152 FIELD(part2, gradV_epilogue)
1153 FIELD(part3, mm_doivj)
1154 FIELD(part3, gradB_epilogue)
1155 FIELD(part4, tmpT_shared_storage)
1156 FIELD(part4, tmp_shared_storage)
1157 FIELD(part4, mm_gradQ)
1158 FIELD(part4, gradQ_epilogue)
1159 FIELD(part4, gradQ_epilogue_lastIter)
1160 FIELD(part5, mm_gradK)
1161 FIELD(part5, gradK_epilogue)
1162 FIELD(part6, gradK_epilogue_final)
1163 FIELD(part6, gradV_epilogue_final)
1164 };
1165
1166 using SharedStorage = typename cutlass::platform::conditional<
1167 kPreload,
1168 SharedStoragePrologue,
1169 SharedStorageNoPrologue>::type;
1170
1171 struct OutputFragments {
1172 typename MatmulGradV::Mma::FragmentC gradV;
1173 typename MatmulGradK::Mma::FragmentC gradK;
1174
1175 CUTLASS_DEVICE void clear() {
1176 gradV.clear();
1177 gradK.clear();
1178 }
1179 };
1180
1181 static bool __host__ check_supported(Params const& p) {
1182 CHECK_ALIGNED_PTR(p.query_ptr, kMinimumAlignment);
1183 CHECK_ALIGNED_PTR(p.key_ptr, kMinimumAlignment);
1184 CHECK_ALIGNED_PTR(p.value_ptr, kMinimumAlignment);
1185 CHECK_ALIGNED_PTR(p.output_ptr, kMinimumAlignment);
1186 CHECK_ALIGNED_PTR(p.grad_output_ptr, kMinimumAlignment);
1187 CHECK_ALIGNED_PTR(p.bias_ptr, kMinimumAlignment);
1188 TORCH_CHECK(
1189 p.num_heads <= 1 || p.lse_strideH % 8 == 0,
1190 "LSE is not correctly aligned (strideH)");
1191 TORCH_CHECK(
1192 p.num_batches <= 1 || p.lse_strideB % 8 == 0,
1193 "LSE is not correctly aligned (strideB)");
1194 TORCH_CHECK(
1195 p.num_heads <= 1 || p.q_strideH % kMinimumAlignment == 0,
1196 "query is not correctly aligned (strideH)");
1197 TORCH_CHECK(
1198 p.num_heads <= 1 || p.k_strideH % kMinimumAlignment == 0,
1199 "key is not correctly aligned (strideH)");
1200 TORCH_CHECK(
1201 p.num_heads <= 1 || p.v_strideH % kMinimumAlignment == 0,
1202 "value is not correctly aligned (strideH)");
1203 TORCH_CHECK(
1204 p.num_batches <= 1 || p.q_strideB % kMinimumAlignment == 0,
1205 "query is not correctly aligned (strideB)");
1206 TORCH_CHECK(
1207 p.num_batches <= 1 || p.k_strideB % kMinimumAlignment == 0,
1208 "key is not correctly aligned (strideB)");
1209 TORCH_CHECK(
1210 p.num_batches <= 1 || p.v_strideB % kMinimumAlignment == 0,
1211 "value is not correctly aligned (strideB)");
1212 TORCH_CHECK(
1213 p.q_strideM % kMinimumAlignment == 0,
1214 "query is not correctly aligned (strideM)");
1215 TORCH_CHECK(
1216 p.k_strideM % kMinimumAlignment == 0,
1217 "key is not correctly aligned (strideM)");
1218 TORCH_CHECK(
1219 p.v_strideM % kMinimumAlignment == 0,
1220 "value is not correctly aligned (strideM)");
1221 if (p.bias_ptr) {
1222 TORCH_CHECK(
1223 p.num_batches <= 1 || p.bias_strideB % kMinimumAlignment == 0,
1224 "attn_bias is not correctly aligned (strideB). ",
1225 "attn_bias.stride(0) = ", p.bias_strideB, ", and should be a "
1226 "multiple of ", kMinimumAlignment, ".");
1227 TORCH_CHECK(
1228 p.num_heads <= 1 || p.bias_strideH % kMinimumAlignment == 0,
1229 "attn_bias is not correctly aligned (strideH). "
1230 "attn_bias.stride(1) = ", p.bias_strideH, ", and should be a "
1231 "multiple of ", kMinimumAlignment, ".");
1232 TORCH_CHECK(
1233 p.num_queries <= 1 || p.bias_strideM % kMinimumAlignment == 0,
1234 "attn_bias is not correctly aligned (strideM). "
1235 "attn_bias.stride(2) = ", p.bias_strideM, ", and should be a ",
1236 "multiple of ", kMinimumAlignment, ".");
1237 }
1238 if (p.grad_bias_ptr) {
1239 TORCH_CHECK(
1240 p.num_batches <= 1 || p.gB_strideB % kMinimumAlignment == 0,
1241 "attn_bias.grad is not correctly aligned (strideB)");
1242 TORCH_CHECK(
1243 p.num_heads <= 1 || p.gB_strideH % kMinimumAlignment == 0,
1244 "attn_bias.grad is not correctly aligned (strideH)");
1245 TORCH_CHECK(
1246 p.gB_strideM % kMinimumAlignment == 0,
1247 "attn_bias.grad is not correctly aligned (strideM)");
1248 }
1249 TORCH_CHECK(
1250 !(p.cu_seqlens_q_ptr && p.bias_ptr),
1251 "CuSeqlen + bias not implemented yet");
1252 TORCH_CHECK(
1253 p.custom_mask_type < NumCustomMaskTypes,
1254 "Invalid value for `custom_mask_type`");
1255 TORCH_CHECK(
1256 p.dropout_prob <= 1.0f && p.dropout_prob >= 0.0f,
1257 "Invalid value for `dropout_prob`");
1258 TORCH_CHECK(
1259 kApplyDropout || p.dropout_prob == 0.0f,
1260 "Set `kApplyDropout`=True to support `dropout_prob > 0`");
1261 TORCH_CHECK(p.head_dim > 0, "Invalid value for `head_dim`");
1262 TORCH_CHECK(p.head_dim_value > 0, "Invalid value for `head_dim_value`");
1263 TORCH_CHECK(p.num_queries > 0, "Invalid value for `num_queries`");
1264 TORCH_CHECK(p.num_keys > 0, "Invalid value for `num_keys`");
1265 TORCH_CHECK(p.num_heads > 0, "Invalid value for `num_heads`");
1266 TORCH_CHECK(p.num_batches > 0, "Invalid value for `num_batches`");
1267 TORCH_CHECK(p.head_dim <= kMaxK, "kMaxK: Expected `head_dim <= kMaxK`");
1268 TORCH_CHECK(
1269 p.head_dim_value <= kMaxK, "kMaxK: Expected `head_dim_value <= kMaxK`");
1270 if (kKeysQueriesAlignedToBlockSize) {
1271 TORCH_CHECK(
1272 p.cu_seqlens_k_ptr == nullptr,
1273 "This kernel does not support cu_seqlen");
1274 TORCH_CHECK(
1275 p.cu_seqlens_q_ptr == nullptr,
1276 "This kernel does not support cu_seqlen");
1277 TORCH_CHECK(
1278 p.num_queries % kBlockSizeI == 0,
1279 "kKeysQueriesAlignedToBlockSize condition not respected");
1280 TORCH_CHECK(
1281 p.num_keys % kBlockSizeJ == 0,
1282 "kKeysQueriesAlignedToBlockSize condition not respected");
1283 }
1284 TORCH_CHECK(
1285 kEnableSplitKeys || p.num_splits_key == 1, "SplitKeys is disabled");
1286 TORCH_CHECK(
1287 p.num_splits_key > 0, "Invalid `num_splits_key` (expected >0)");
1288 TORCH_CHECK(
1289 p.num_splits_key <= cutlass::ceil_div(p.num_keys, kBlockSizeJ),
1290 "Invalid `num_splits_key` (",
1291 p.num_splits_key,
1292 ") - too large for `num_keys` = ",
1293 p.num_keys);
1294 if (p.window_size != 0) {
1295 TORCH_CHECK(
1296 p.custom_mask_type != NoCustomMask,
1297 "LocalAttention only supported in causal mode");
1298 }
1299 return true;
1300 }
1301
1302 static CUTLASS_DEVICE void attention_kernel(Params p) {
1303 extern __shared__ char smem_buffer[];
1304 SharedStorage& shared_storage = *((SharedStorage*)smem_buffer);
1305
1306 uint16_t thread_id = threadIdx.x;
1307 uint8_t warp_id = warp_uniform(thread_id / 32);
1308 uint8_t lane_id = thread_id % 32;
1309
1310 int32_t key_start = p.split_key_device() * kBlockSizeJ;
1311 if (key_start >= p.num_keys) {
1312 return;
1313 }
1314 if (kPrologueQK) {
1315 int32_t query_start = getQueryStart(p, key_start);
1316 prologueQkNextIteration<true>(
1317 shared_storage, p, query_start, key_start, warp_id, lane_id);
1318 }
1319
1320 // Computes (dO*out).sum(-1) and writes it to `p.delta_ptr`
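// i.e. delta[q] = sum_k dO[q, k] * O[q, k]; this is the `Di` term used
// further down in dSij = Pij * (dPij - Di)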
1321 if (kKernelComputesDelta) {
1322 constexpr int kOptimalElements =
1323 128 / cutlass::sizeof_bits<scalar_t>::value;
1324 if (p.head_dim_value % kOptimalElements == 0) {
1325 for (int query_start = 0; query_start < p.num_queries;
1326 query_start += kBlockSizeI) {
1327 computeDelta<kOptimalElements>(p, query_start, warp_id, lane_id);
1328 }
1329 } else {
1330 for (int query_start = 0; query_start < p.num_queries;
1331 query_start += kBlockSizeI) {
1332 computeDelta<1>(p, query_start, warp_id, lane_id);
1333 }
1334 }
1335 __syncthreads();
1336 }
1337
1338 OutputFragments output_frags;
1339
1340 curandStatePhilox4_32_10_t rng_state_init;
1341
1342 if (kApplyDropout) {
1343 // See Note [Seed and Offset Device]
1344 auto seeds = at::cuda::philox::unpack(p.rng_engine_inputs);
1345 // each element of the attention matrix P with shape
1346 // (batch_sz, n_heads, n_queries, n_keys) is associated with a single
1347 // offset in RNG sequence. we initialize the RNG state with offset that
1348 // starts at the beginning of a (n_queries, n_keys) matrix for this
1349 // block's batch_id and head_id
1350 // initializing rng state is very expensive, so we run once per kernel,
1351 // rather than once per iteration. each iteration takes a copy of the
1352 // initialized RNG state and offsets it as needed.
1353 curand_init(
1354 std::get<0>(seeds),
1355 0,
1356 std::get<1>(seeds) + p.dropout_batch_head_rng_offset,
1357 &rng_state_init);
1358 }
1359
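// Outer loop: this threadblock walks the key blocks it owns
// (key_start, key_start + num_splits_key * kBlockSizeJ, ...). The inner
// `while` visits every query block that interacts with the current key
// block; `getQueryStart`/`incrIteration` skip query blocks that cannot
// contribute (e.g. because of causal masking). When `kOutputInRF`, dK/dV
// stay in registers across the whole key block and are flushed to gmem
// once at the end.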
1360 CUTLASS_PRAGMA_UNROLL
1361 for (; key_start < p.num_keys;
1362 key_start += p.num_splits_key_device() * kBlockSizeJ) {
1363 output_frags.clear();
1364
1365 int32_t next_key = key_start;
1366 int32_t query_start = getQueryStart(p, key_start);
1367 while (next_key == key_start && query_start < p.num_queries) {
1368 // This line here
1369 // vvvvvvvvvvvvvv
1370 warp_id = warp_uniform(warp_id);
1371 // ^^^^^^^^^^^^^^
1372 // ... makes everything use less RF and be 10% faster. Why?
1373 // I don't know. My theory is that it forces `nvcc` to
1374 // re-compute indices, offsets etc... and not keep them
1375 // from the previous iteration, which prevents MASSIVE
1376 // register spilling.
1377
1378 processBlockIJ<kKeysQueriesAlignedToBlockSize>(
1379 shared_storage,
1380 output_frags,
1381 p,
1382 query_start,
1383 key_start,
1384 rng_state_init,
1385 warp_id,
1386 lane_id);
1387
1388 int32_t next_query;
1389 incrIteration(p, query_start, key_start, next_query, next_key);
1390 query_start = next_query;
1391 }
1392 if (kOutputInRF) {
1393 writeFragsToGmem<kKeysQueriesAlignedToBlockSize>(
1394 shared_storage, output_frags, p, key_start, warp_id, lane_id);
1395 } else if (getQueryStart(p, key_start) >= p.num_queries) {
1396 zfillGradKV<kKeysQueriesAlignedToBlockSize>(
1397 p, key_start, warp_id, lane_id);
1398 }
1399 __syncthreads();
1400 }
1401 }
1402
1403 template <bool skipBoundsChecks>
1404 static CUTLASS_DEVICE void zfillGradKV(
1405 Params const& p,
1406 int32_t key_start,
1407 uint8_t warp_id,
1408 uint8_t lane_id) {
1409 constexpr int kThreadsPerKey = 8;
1410 constexpr int kParallelKeys = kNumThreads / kThreadsPerKey;
1411 static_assert(kBlockSizeJ % kParallelKeys == 0, "");
1412 // This function is not really optimized, but should rarely be used.
1413 // It's only needed when some keys are "useless" because no query
1414 // attends to them, e.g. due to causal masking
1415
1416 int thread_id = 32 * warp_id + lane_id;
1417 int k_shift = lane_id % kThreadsPerKey;
1418
1419 CUTLASS_PRAGMA_UNROLL
1420 for (int j = 0; j < kBlockSizeJ; j += kParallelKeys) {
1421 int key = key_start + j + (thread_id / kThreadsPerKey);
1422 if (!skipBoundsChecks && key >= p.num_keys) {
1423 continue;
1424 }
1425 auto gv_ptr = p.grad_value_ptr + key * p.gV_strideM();
1426 auto gk_ptr = p.grad_key_ptr + key * p.gK_strideM();
1427
1428 for (int k = k_shift; k < p.head_dim_value; k += kThreadsPerKey) {
1429 gv_ptr[k] = scalar_t(0);
1430 }
1431 for (int k = k_shift; k < p.head_dim; k += kThreadsPerKey) {
1432 gk_ptr[k] = scalar_t(0);
1433 }
1434 }
1435 }
1436
1437 template <bool skipBoundsChecks>
1438 static CUTLASS_DEVICE void processBlockIJ(
1439 SharedStorage& shared_storage,
1440 OutputFragments& output_frags,
1441 Params& p,
1442 int32_t query_start,
1443 int32_t key_start,
1444 const curandStatePhilox4_32_10_t& curand_state_init,
1445 uint8_t warp_id,
1446 uint8_t lane_id) {
1447 cutlass::Array<cutlass::uint1b_t, MatmulDOIVJ::Mma::FragmentC::kElements>
1448 dropout_keep_mask_doivj;
1449 dropout_keep_mask_doivj.fill(cutlass::uint1b_t{1});
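// Inverted-dropout scaling: elements that are kept are multiplied by
// 1 / (1 - dropout_prob) so that the expected value matches the
// no-dropout case.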
1450 const float dropout_scale =
1451 kApplyDropout ? 1.0 / (1.0 - p.dropout_prob) : 1.0f;
1452
1453 cutlass::MatrixCoord no_offset{0, 0};
1454 accum_t scale = p.scale;
1455 int16_t thread_id = 32 * warp_id + lane_id;
1456
1457 auto rematerializeThreadIds = [&]() {
1458 // Prevents `nvcc` from keeping values deduced from
1459 // `thread_id`, `warp_id`, ... in RF - to reduce register pressure
1460 warp_id = warp_uniform(thread_id / 32);
1461 lane_id = thread_id % 32;
1462 thread_id = 32 * warp_id + lane_id;
1463 };
1464
1465 bool isFirstQuery = (query_start == getQueryStart(p, key_start));
1466 int32_t next_query, next_key;
1467 incrIteration(p, query_start, key_start, next_query, next_key);
1468 bool isLastQuery = next_key != key_start;
1469
1470 accum_t di_rf = accum_t(0);
1471 if (thread_id < kBlockSizeI) {
1472 if (query_start + thread_id < p.num_queries) {
1473 di_rf = p.delta_ptr[query_start + thread_id];
1474 }
1475 shared_storage.di()[thread_id] = di_rf;
1476 }
1477
1478 int32_t num_queries_in_block = skipBoundsChecks
1479 ? MatmulQK::Mma::Shape::kN
1480 : warp_uniform(cutlass::fast_min(
1481 (int32_t)MatmulQK::Mma::Shape::kN, p.num_queries - query_start));
1482 int32_t num_keys_in_block = skipBoundsChecks
1483 ? MatmulQK::Mma::Shape::kM
1484 : warp_uniform(cutlass::fast_min(
1485 (int32_t)MatmulQK::Mma::Shape::kM, p.num_keys - key_start));
1486
1487 auto prologueGradV = [&](int col) {
1488 typename MatmulGradV::Mma::IteratorB iterator_dO(
1489 {int32_t(p.gO_strideM)},
1490 const_cast<scalar_t*>(p.grad_output_ptr + query_start * p.gO_strideM + col),
1491 {num_queries_in_block, p.head_dim_value - col},
1492 thread_id,
1493 no_offset);
1494 MatmulGradV::Mma::prologue(
1495 shared_storage.mm_gradV(),
1496 iterator_dO,
1497 thread_id,
1498 num_queries_in_block);
1499 };
1500 auto prologueGradQ = [&](int col) {
1501 typename MatmulGradQ::Mma::IteratorB iterator_K(
1502 {int32_t(p.k_strideM)},
1503 const_cast<scalar_t*>(p.key_ptr + key_start * p.k_strideM + col),
1504 {num_keys_in_block, p.head_dim - col},
1505 thread_id,
1506 no_offset);
1507 MatmulGradQ::Mma::prologue(
1508 shared_storage.mm_gradQ(), iterator_K, thread_id, num_keys_in_block);
1509 };
1510 auto prologueGradK = [&](int col) {
1511 typename MatmulGradK::Mma::IteratorB iterator_Q(
1512 {int32_t(p.q_strideM)},
1513 const_cast<scalar_t*>(p.query_ptr + query_start * p.q_strideM + col),
1514 {num_queries_in_block, p.head_dim - col},
1515 thread_id,
1516 no_offset);
1517 MatmulGradK::Mma::prologue(
1518 shared_storage.mm_gradK(),
1519 iterator_Q,
1520 thread_id,
1521 num_queries_in_block);
1522 };
1523 auto prologueDOV = [&]() {
1524 typename MatmulDOIVJ::Mma::IteratorA iterator_A(
1525 {int32_t(p.gO_strideM)},
1526 const_cast<scalar_t*>(p.grad_output_ptr + query_start * p.gO_strideM),
1527 {num_queries_in_block, p.head_dim_value},
1528 thread_id,
1529 no_offset);
1530 typename MatmulDOIVJ::Mma::IteratorB iterator_B(
1531 {int32_t(p.v_strideM)},
1532 const_cast<scalar_t*>(p.value_ptr + key_start * p.v_strideM),
1533 {p.head_dim_value, num_keys_in_block},
1534 thread_id,
1535 no_offset);
1536 MatmulDOIVJ::Mma::prologue(
1537 shared_storage.mm_doivj(),
1538 iterator_A,
1539 iterator_B,
1540 thread_id,
1541 p.head_dim_value);
1542 };
1543
1544 /////////////////////////////////////////////////////////////////////////////////////////////////
1545 // MatmulQK
1546 /////////////////////////////////////////////////////////////////////////////////////////////////
1547 {
1548 using Mma = typename MatmulQK::Mma;
1549
1550 cutlass::gemm::GemmCoord problem_size(
1551 num_keys_in_block,
1552 num_queries_in_block,
1553 p.head_dim // k
1554 );
1555
1556 // k_j
1557 typename Mma::IteratorA iterator_A(
1558 {int32_t(p.k_strideM)},
1559 const_cast<scalar_t*>(p.key_ptr + key_start * p.k_strideM),
1560 {problem_size.m(), problem_size.k()},
1561 thread_id,
1562 no_offset);
1563
1564 // q_i.transpose(-2, -1)
1565 typename Mma::IteratorB iterator_B(
1566 {int32_t(p.q_strideM)},
1567 const_cast<scalar_t*>(p.query_ptr + query_start * p.q_strideM),
1568 {problem_size.k(), problem_size.n()},
1569 thread_id,
1570 no_offset);
1571
1572 Mma mma(
1573 shared_storage.mm_qk_k(),
1574 shared_storage.mm_qk_q(),
1575 thread_id,
1576 warp_id,
1577 lane_id);
1578
1579 typename Mma::FragmentC accum;
1580
1581 accum.clear();
1582
1583 auto gemm_k_iterations =
1584 (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
1585
1586 // Compute threadblock-scoped matrix multiply-add
1587 mma.set_prologue_done(kPrologueQK);
1588 mma.set_zero_outside_bounds(!skipBoundsChecks);
1589 mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum);
1590 accum = cutlass::multiplies<typename Mma::FragmentC>()(scale, accum);
1591
1592 // Epilogue: subtract the LSE and exponentiate, then store that to our shared memory buffer
1593 // shmem <- (matmul_result -
1594 // logsumexp[i_start:i_end].unsqueeze(1)).exp()
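// Put differently (rough sketch of what accumApplyLSEToSmem does below,
// once the optional bias and masks have been folded into `accum`):
//   attn_T[k][q] ~ exp(accum[k][q] - logsumexp[query_start + q])
// i.e. the forward softmax probabilities Pij are recomputed from the saved
// logsumexp instead of being stored during the forward pass.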
1595 int warp_idx_mn_0 =
1596 warp_id % (Mma::Base::WarpCount::kM * Mma::Base::WarpCount::kN);
1597 auto output_tile_coords = cutlass::MatrixCoord{
1598 warp_idx_mn_0 % Mma::Base::WarpCount::kM,
1599 warp_idx_mn_0 / Mma::Base::WarpCount::kM};
1600
1601 // apply bias if applicable
1602 if (p.bias_ptr != nullptr) {
1603 // load bias tile Bij into shared memory
1604 typename MatmulQK::BiasLoader::GmemTileIterator bias_iter(
1605 {cutlass::layout::RowMajor(p.bias_strideM)},
1606 const_cast<scalar_t*>(p.bias_ptr + query_start * p.bias_strideM + key_start),
1607 {num_queries_in_block, num_keys_in_block},
1608 thread_id);
1609 cutlass::TensorRef<scalar_t, cutlass::layout::RowMajor> bias_tensor_ref(
1610 shared_storage.bias().data(),
1611 cutlass::layout::RowMajor(MatmulQK::ThreadblockShape::kM));
1612 typename MatmulQK::BiasLoader::SmemTileIterator smem_tile_iter(
1613 bias_tensor_ref, thread_id);
1614 MatmulQK::BiasLoader::load(bias_iter, smem_tile_iter);
1615
1616 // Pij += Bij, where Pij is in register fragment and Bij is in shmem
1617 auto lane_offset = MatmulQK::AccumLambdaIterator::get_lane_offset(
1618 lane_id, warp_id, output_tile_coords);
1619 MatmulQK::AccumLambdaIterator::iterateRows(
1620 lane_offset,
1621 [&](int accum_n) {},
1622 [&](int accum_m, int accum_n, int idx) {
1623 // remember we are transposed
1624 accum[idx] += bias_tensor_ref.at({accum_n, accum_m});
1625 },
1626 [&](int accum_n) {});
1627 }
1628
1629 // Apply mask
1630 if (p.custom_mask_type == CausalFromTopLeft ||
1631 p.custom_mask_type == CausalFromBottomRight) {
1632 auto lane_offset = MatmulQK::AccumLambdaIterator::get_lane_offset(
1633 lane_id, warp_id, output_tile_coords);
1634 int shift = query_start - key_start;
1635 if (p.custom_mask_type == CausalFromBottomRight) {
1636 shift += p.num_keys - p.num_queries;
1637 }
1638 // current_key = key_start + accum_m
1639 // current_query = query_start + accum_n
1640 // mask if: `current_key > current_query`
1641 MatmulQK::AccumLambdaIterator::iterateRows(
1642 lane_offset,
1643 [&](int accum_m) {},
1644 [&](int accum_m, int accum_n, int idx) {
1645 if (accum_m > accum_n + shift) {
1646 accum[idx] =
1647 -cutlass::platform::numeric_limits<accum_t>::infinity();
1648 }
1649 },
1650 [&](int accum_m) {});
1651 }
1652 if (p.window_size > 0) {
1653 auto lane_offset = MatmulQK::AccumLambdaIterator::get_lane_offset(
1654 lane_id, warp_id, output_tile_coords);
1655 int shift = query_start - key_start - p.window_size;
1656 // current_key = key_start + accum_m
1657 // current_query = query_start + accum_n
1658 // mask if: `current_key <= current_query - window_size`
1659 // i.e. if accum_m <= accum_n + query_start - window_size - key_start
1660
1661 MatmulQK::AccumLambdaIterator::iterateRows(
1662 lane_offset,
1663 [&](int accum_m) {},
1664 [&](int accum_m, int accum_n, int idx) {
1665 if (accum_m <= accum_n + shift) {
1666 accum[idx] =
1667 -cutlass::platform::numeric_limits<accum_t>::infinity();
1668 }
1669 },
1670 [&](int accum_m) {});
1671 }
1672 __syncthreads();
1673 if (kPrologueGV) {
1674 prologueGradV(0);
1675 }
1676 if (kPrologueDOV) {
1677 prologueDOV();
1678 }
1679
1680 MatmulQK::B2bGemm::accumApplyLSEToSmem(
1681 shared_storage.attn_shared_storage(),
1682 accum,
1683 p.logsumexp_ptr + query_start,
1684 problem_size.n(),
1685 thread_id,
1686 warp_id,
1687 lane_id,
1688 output_tile_coords);
1689 #if 0
1690 auto accum_ref_attnT = shared_storage.attn_shared_storage().accum_ref();
1691 PRINT_TENSOR4x4_T0_L0("attn_T", accum_ref_attnT);
1692 #endif
1693
1694 // if we are using dropout, compute Zij, writing it to shared memory.
1695 // each element of Zij is:
1696 // - 0 with probability dropout_p
1697 // - 1 / (1 - dropout_p) with probability 1 - dropout_p
1698 if (kApplyDropout) {
1699 auto zij = shared_storage.zij().accum_ref();
1700 // each thread generates a contiguous sequence of elements in Zij, all
1701 // in the same row. the reason they have to come from the same row is
1702 // that sampling random numbers from a contiguous random number sequence
1703 // is much more efficient than jumping around, and the linear offset of
1704 // each element of Z (the global matrix) maps to an offset in a random
1705 // number sequence. for Z, the end of a row and the beginning of the
1706 // next have adjacent offsets, but for Zij (tile of global matrix), this
1707 // is not necessarily the case.
1708 // We must fill the entire `zij` shmem with values (even out of bounds
1709 // on the K-dimension) otherwise we can get NaNs during the GEMM
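// Note on the RNG offsets: the `skipahead` below advances the Philox
// state by (query_start + q) * num_keys + (key_start + k), i.e. the
// element's linear index in the global Z matrix. Using the same
// seed/offset scheme as the forward pass is what lets the backward
// regenerate an identical dropout mask instead of storing it.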
1710 const int kQueriesPerBlock = kBlockSizeI;
1711 const int threads_per_row = cutlass::fast_min(
1712 int32_t(kNumThreads / kQueriesPerBlock), num_keys_in_block);
1713 const int elts_per_thread = cutlass::round_nearest(
1714 cutlass::ceil_div(num_keys_in_block, threads_per_row), 4);
1715
1716 const int thread_i = thread_id / threads_per_row;
1717 const int thread_start_j =
1718 (thread_id % threads_per_row) * elts_per_thread;
1719
1720 if (thread_i < kQueriesPerBlock && thread_start_j < num_keys_in_block) {
1721 curandStatePhilox4_32_10_t curand_state = curand_state_init;
1722 skipahead(
1723 (query_start + thread_i) * p.num_keys +
1724 (key_start + thread_start_j),
1725 &curand_state);
1726
1727 // generate elements of Zij, 4 elements at a time
1728 for (int zij_start_col_idx = thread_start_j; zij_start_col_idx <
1729 cutlass::fast_min<int32_t>(thread_start_j + elts_per_thread,
1730 num_keys_in_block);
1731 zij_start_col_idx += 4) {
1732 const float4 rand_uniform_quad = curand_uniform4(&curand_state);
1733
1734 CUTLASS_PRAGMA_UNROLL
1735 for (int quad_idx = 0; quad_idx < 4; ++quad_idx) {
1736 // we'll write Zij transposed since attention is also transposed
1737 // during the matmul to compute dV.
1738 zij.at({zij_start_col_idx + quad_idx /*k*/, thread_i /*q*/}) =
1739 (&rand_uniform_quad.x)[quad_idx] > p.dropout_prob
1740 ? scalar_t(dropout_scale)
1741 : scalar_t(0);
1742 }
1743 }
1744 }
1745 __syncthreads();
1746 #if 0
1747 PRINT_TENSOR4x4_T0_L0("zij", zij);
1748 PRINT_TENSOR4x4_T0_L0_START("zij", zij, kBlockSizeJ - 4, kBlockSizeI - 4);
1749 #endif
1750
1751 // Save mask for later DOIVJ matmul
1752
1753 int warp_idx_mn_0 = warp_id %
1754 (MatmulDOIVJ::Mma::Base::WarpCount::kM *
1755 MatmulDOIVJ::Mma::Base::WarpCount::kN);
1756 auto output_tile_coords_doivj = cutlass::MatrixCoord{
1757 warp_idx_mn_0 % MatmulDOIVJ::Mma::Base::WarpCount::kM,
1758 warp_idx_mn_0 / MatmulDOIVJ::Mma::Base::WarpCount::kM};
1759 auto lane_offset = MatmulDOIVJ::AccumLambdaIterator::get_lane_offset(
1760 lane_id, warp_id, output_tile_coords_doivj);
1761 MatmulDOIVJ::AccumLambdaIterator::iterateRows(
1762 lane_offset,
1763 [&](int accum_m) {},
1764 [&](int accum_m /*q*/, int accum_n /*k*/, int idx) {
1765 if (zij.at({accum_n, accum_m}) == scalar_t(0)) {
1766 dropout_keep_mask_doivj[idx] = cutlass::uint1b_t{0};
1767 }
1768 },
1769 [&](int accum_m) {});
1770 }
1771 __syncthreads();
1772 }
1773 rematerializeThreadIds();
1774
1775 /////////////////////////////////////////////////////////////////////////////////////////////////
1776 // GradV matmul
1777 //
1778 // grad_v[j_start:j_end] += attn_T @ do_i
1779 /////////////////////////////////////////////////////////////////////////////////////////////////
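// Rough accumulation scheme over the query loop for one key block j,
// assuming the accumulator cannot stay in registers (!kOutputInRF):
//   first query block i : gradV  = Pij.T (* Zij) @ dOi
//   middle query blocks : reload the partial gradV from the gmem
//                         workspace, accumulate, store it back
//                         (staging only when kNeedsAccumGradV)
//   last query block    : run the epilogue and write grad_value[j]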
1780 constexpr bool kSingleIterationGradV =
1781 kMaxK <= MatmulGradV::ThreadblockShape::kN;
1782 for (int col = 0; col < (kSingleIterationGradV ? 1 : p.head_dim_value);
1783 col += MatmulGradV::ThreadblockShape::kN) {
1784 using Mma = typename MatmulGradV::Mma;
1785 using AccumTileGmem = typename MatmulGradQ::AccumTileGmem;
1786
1787 cutlass::gemm::GemmCoord problem_size(
1788 num_keys_in_block, p.head_dim_value - col, num_queries_in_block);
1789 auto createEpilogueIter = [&]() {
1790 return typename MatmulGradV::OutputTileIterator(
1791 typename MatmulGradV::OutputTileIterator::Params{p.gV_strideM()},
1792 p.grad_value_ptr + key_start * p.gV_strideM() + col,
1793 {num_keys_in_block, p.head_dim_value - col},
1794 thread_id);
1795 };
1796 typename Mma::IteratorB iterator_B(
1797 {int32_t(p.gO_strideM)},
1798 const_cast<scalar_t*>(p.grad_output_ptr + query_start * p.gO_strideM + col),
1799 {num_queries_in_block, p.head_dim_value - col},
1800 thread_id,
1801 no_offset);
1802
1803 // if dropout: dVj += (Pij.T * Zij) @ dOi
1804 // otherwise: dVj += Pij.T @ dOi
1805 Mma mma(
1806 // operand A: Pij.T
1807 shared_storage.attn_shared_storage().accum_ref(),
1808 // operand A_scale Zij.T:
1809 // if we're using dropout, operand A is Pij_dropped.T = Pij.T * Zij.T
1810 // which is computed on the fly as fragments of Pij.T are loaded in
1811 shared_storage.zij().accum_ref(),
1812 // operand B: dOi - which was loaded into shared memory previously
1813 // when we computed dVj
1814 shared_storage.mm_gradV().operand_B_ref(),
1815 thread_id,
1816 warp_id,
1817 lane_id);
1818
1819 int storage_id = col / MatmulGradV::ThreadblockShape::kN;
1820 AccumTileGmem gmem_tile{
1821 p.workspace_gv + storage_id * AccumTileGmem::kElementsStored};
1822 if (!kOutputInRF) {
1823 if (isFirstQuery || !kNeedsAccumGradV) {
1824 output_frags.gradV.clear();
1825 } else {
1826 gmem_tile.load(output_frags.gradV, thread_id);
1827 }
1828 }
1829 mma.set_prologue_done(kPrologueGV);
1830
1831 auto gemm_k_iterations =
1832 (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
1833
1834 // Compute threadblock-scoped matrix multiply-add
1835 __syncthreads();
1836
1837 mma(gemm_k_iterations,
1838 output_frags.gradV,
1839 iterator_B,
1840 output_frags.gradV);
1841 __syncthreads();
1842 if (kPrologueGV && !kSingleIterationGradV &&
1843 col + MatmulGradV::ThreadblockShape::kN < p.head_dim_value) {
1844 prologueGradV(col + MatmulGradV::ThreadblockShape::kN);
1845 }
1846
1847 if (!kOutputInRF) {
1848 if (kNeedsAccumGradV && !isLastQuery) {
1849 gmem_tile.store(output_frags.gradV, thread_id);
1850 } else {
1851 accumulateInGmem<MatmulGradV>(
1852 shared_storage.gradV_epilogue(),
1853 output_frags.gradV,
1854 createEpilogueIter(),
1855 isFirstQuery || kNeedsAccumGradV,
1856 warp_id,
1857 lane_id);
1858 }
1859 }
1860 }
1861 __syncthreads();
1862
1863 /////////////////////////////////////////////////////////////////////////////////////////////////
1864 // MatmulDOIVJ
1865 /////////////////////////////////////////////////////////////////////////////////////////////////
1866 {
1867 using Mma = typename MatmulDOIVJ::Mma;
1868 // do_i
1869 typename Mma::IteratorA iterator_A(
1870 {int32_t(p.gO_strideM)},
1871 const_cast<scalar_t*>(p.grad_output_ptr + query_start * p.gO_strideM),
1872 {num_queries_in_block, p.head_dim_value},
1873 thread_id,
1874 no_offset);
1875
1876 // v_j.transpose(-2, -1)
1877 typename Mma::IteratorB iterator_B(
1878 {int32_t(p.v_strideM)},
1879 const_cast<scalar_t*>(p.value_ptr + key_start * p.v_strideM),
1880 {p.head_dim_value, num_keys_in_block},
1881 thread_id,
1882 no_offset);
1883
1884 Mma mma(shared_storage.mm_doivj(), thread_id, warp_id, lane_id);
1885 mma.set_prologue_done(kPrologueDOV);
1886 mma.set_zero_outside_bounds(!skipBoundsChecks);
1887
1888 typename Mma::FragmentC accum;
1889
1890 accum.clear();
1891
1892 auto gemm_k_iterations =
1893 (p.head_dim_value + Mma::Shape::kK - 1) / Mma::Shape::kK;
1894
1895 // Compute threadblock-scoped matrix multiply-add
1896 mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum);
1897 __syncthreads();
1898 if (kPrologueGQ) {
1899 prologueGradQ(0);
1900 }
1901 if (kPrologueGK) {
1902 prologueGradK(0);
1903 }
1904
1905 int warp_idx_mn_0 =
1906 warp_id % (Mma::Base::WarpCount::kM * Mma::Base::WarpCount::kN);
1907 auto output_tile_coords = cutlass::MatrixCoord{
1908 warp_idx_mn_0 % Mma::Base::WarpCount::kM,
1909 warp_idx_mn_0 / Mma::Base::WarpCount::kM};
1910 // TODO: This must be terribly inefficient. There must be a better way
1911 // tmp [RF] <- (accum [RF] - Di [smem] ) * attn_T.T [smem]
1912 // attn_shared_storage [smem] <- tmp.T
1913 // tmp_shared_storage [smem] <- tmp
1914 {
1915 using LambdaIterator = typename MatmulDOIVJ::AccumLambdaIterator;
1916 auto lane_offset = LambdaIterator::get_lane_offset(
1917 lane_id, warp_id, output_tile_coords);
1918 // if dropout was used, compute dPij = dPij_dropped * Zij
1919 if (kApplyDropout) {
1920 LambdaIterator::iterateRows(
1921 lane_offset,
1922 [&](int accum_m) {},
1923 [&](int accum_m, int accum_n, int idx) {
1924 if (dropout_keep_mask_doivj[idx].get()) {
1925 accum[idx] *= dropout_scale;
1926 } else {
1927 accum[idx] = 0;
1928 }
1929 },
1930 [&](int accum_m) {});
1931 }
1932
1933 auto attn_T = shared_storage.attn_shared_storage().accum_ref();
1934 #if 0
1935 PRINT_B0_T0("doivj_dropped");
1936 print_warp_accum<LambdaIterator>(accum, lane_offset, 4, 4);
1937 PRINT_TENSOR4x4_T0_L0("attn_T", attn_T)
1938 #endif
1939 accum_t current_di;
1940 // dSij = (dPij - Di) * Pij
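// (Standard softmax backward, as a quick sketch: with Pij = softmax(Sij)
//  and dPij = dOi @ Vj.T, dSij = Pij * (dPij - sum_j Pij * dPij), and that
//  row-wise sum equals Di = sum(dOi * Oi) computed in computeDelta, which
//  is why only the vector Di has to be kept around.)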
1941 LambdaIterator::iterateRows(
1942 lane_offset,
1943 [&](int accum_m) { current_di = shared_storage.di()[accum_m]; },
1944 [&](int accum_m, int accum_n, int idx) {
1945 // TODO: without this bounds check we can get NaNs, as the
1946 // out-of-bounds values might be infs (only seen on f16 so far)
1947 if (skipBoundsChecks ||
1948 (accum_m < num_queries_in_block &&
1949 accum_n < num_keys_in_block)) {
1950 accum_t attn = attn_T.at({accum_n, accum_m});
1951 accum[idx] = (accum[idx] - current_di) * attn;
1952 } else {
1953 accum[idx] = 0;
1954 }
1955 },
1956 [&](int accum_m) {
1957
1958 });
1959
1960 // store bias gradient tile dBij to global memory,
1961 // where dBij = dSij = Pij * (dPij - Di)
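// (dBij equals dSij without the extra `scale` factor because the bias is
//  added to the logits after the scale multiplication; `scale` is applied
//  to `accum` further below, only for the dQ / dK matmuls.)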
1962 if (p.grad_bias_ptr != nullptr) {
1963 typename MatmulDOIVJ::BiasGradEpilogue::OutputTileIterator
1964 output_iter(
1965 typename MatmulDOIVJ::BiasGradEpilogue::OutputTileIterator::
1966 Params{p.gB_strideM},
1967 // grad_bias_ptr is offset to point at beginning of
1968 // matrix of shape (queries, keys) for a given
1969 // (batch_id, head_id) the pointer arithmetic here produces
1970 // a pointer to the start of the current tile within that
1971 // matrix
1972 p.grad_bias_ptr + query_start * p.gB_strideM + key_start,
1973 {num_queries_in_block, num_keys_in_block},
1974 thread_id);
1975
1976 // no-op epilogue operator - just casting and storing contents of
1977 // accum to global memory
1978 typename MatmulDOIVJ::BiasGradEpilogue::OutputOp output_op({1, 1});
1979 typename MatmulDOIVJ::BiasGradEpilogue epilogue(
1980 shared_storage.gradB_epilogue(), thread_id, warp_id, lane_id);
1981 epilogue(output_op, output_iter, accum, output_iter);
1982 }
1983
1984 accum = accum * scale;
1985
1986 #if 0
1987 PRINT_B0_T0("(doivj - di) * attn * scale");
1988 print_warp_accum<LambdaIterator>(accum, lane_offset, 4, 4);
1989 #endif
1990
1991 __syncthreads();
1992 if (!MatmulGradK::DefaultMmaFromSmem::kIsTransposedA) {
1993 auto tmpT = shared_storage.tmpT_shared_storage().accum_ref();
1994 // attn <- attn_T.T
1995 LambdaIterator::iterateRows(
1996 lane_offset,
1997 [&](int accum_m) {},
1998 [&](int accum_m, int accum_n, int idx) {
1999 tmpT.at({accum_n, accum_m}) = scalar_t(accum[idx]);
2000 },
2001 [&](int accum_m) {});
2002 }
2003 }
2004
2005 MatmulDOIVJ::B2bGemm::accumToSmem(
2006 shared_storage.tmp_shared_storage(),
2007 accum,
2008 lane_id,
2009 output_tile_coords);
2010 __syncthreads();
2011 }
2012 // Force `nvcc` to recompute values that depend on the variables just below
2013 // to use less RF and prevent some spilling
2014 p.head_dim = warp_uniform(p.head_dim);
2015 p.k_strideM = warp_uniform(p.k_strideM);
2016 rematerializeThreadIds();
2017
2018 /////////////////////////////////////////////////////////////////////////////////////////////////
2019 // GradQ matmul
2020 //
2021 // grad_q[i_start:i_end] += tmp @ k_j
2022 /////////////////////////////////////////////////////////////////////////////////////////////////
2023 // Skip the loop & associated branches if we know at compile time the number
2024 // of iterations
2025 constexpr bool kSingleIterationGradQ =
2026 kMaxK <= MatmulGradQ::ThreadblockShape::kN;
2027 for (int col = 0; col < (kSingleIterationGradQ ? 1 : p.head_dim);
2028 col += MatmulGradQ::ThreadblockShape::kN) {
2029 using Mma = typename MatmulGradQ::Mma;
2030 using AccumTileGmem = typename MatmulGradQ::AccumTileGmem;
2031
2032 cutlass::gemm::GemmCoord problem_size(
2033 num_queries_in_block,
2034 p.head_dim - col,
2035 num_keys_in_block);
2036
2037 // k_j
2038 typename Mma::IteratorB iterator_B(
2039 {int32_t(p.k_strideM)},
2040 const_cast<scalar_t*>(p.key_ptr + key_start * p.k_strideM + col),
2041 {problem_size.k(), problem_size.n()},
2042 thread_id,
2043 no_offset);
2044
2045
2046 Mma mma(
2047 // operand A: dSij
2048 shared_storage.tmp_shared_storage().accum_ref(),
2049 // operand B: Kj
2050 shared_storage.mm_gradQ().operand_B_ref(),
2051 thread_id,
2052 warp_id,
2053 lane_id);
2054
2055 typename Mma::FragmentC accum;
2056
2057 int col_id = col / MatmulGradQ::ThreadblockShape::kN;
2058 int num_cols = kSingleIterationGradQ
2059 ? 1
2060 : ceil_div(p.head_dim, MatmulGradQ::ThreadblockShape::kN);
2061 int storage_id = (col_id + query_start / kBlockSizeI * num_cols);
2062
2063 if (p.num_splits_key_device() > 1) {
2064 AtomicLock::acquire(
2065 &p.workspace_gq[storage_id].lock,
2066 p.split_key_device() + 1,
2067 thread_id);
2068 // Make sure we can see other block's output
2069 __threadfence();
2070 }
2071
2072 AccumTileGmem gmem_tile{&p.workspace_gq[storage_id].buffer[0]};
2073 if (!kNeedsAccumGradQ ||
2074 (p.num_splits_key_device() == 1 && key_start == 0)) {
2075 // if we know we are the first to access it, we know it's only zeros.
2076 // Avoids a load from gmem (and gmem init as well)
2077 accum.clear();
2078 } else {
2079 gmem_tile.load(accum, thread_id);
2080 }
2081
2082 auto gemm_k_iterations =
2083 (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
2084
2085 // Compute threadblock-scoped matrix multiply-add
2086 __syncthreads();
2087 mma.set_prologue_done(kPrologueGQ);
2088 mma(gemm_k_iterations, accum, iterator_B, accum);
2089 __syncthreads();
2090 bool isLastColumn = kSingleIterationGradQ ||
2091 (col + MatmulGradQ::ThreadblockShape::kN >= p.head_dim);
2092 if (kPrologueGQ && !isLastColumn) {
2093 prologueGradQ(col + MatmulGradQ::ThreadblockShape::kN);
2094 }
2095
2096 bool isLast = [&]() {
2097 int32_t next_key = key_start + p.num_splits_key_device() * kBlockSizeJ;
2098 if (p.num_keys <= next_key) {
2099 return true;
2100 }
2101 if (query_start < getSmallestQueryForKey(p, next_key)) {
2102 return true;
2103 }
2104 return false;
2105 }();
2106 // Output results
2107 if (p.num_splits_key_device() > 1) {
2108 int32_t numAddsSoFar = -1;
2109 if (isLast && thread_id == 0) {
2110 numAddsSoFar = atomicAdd(&p.workspace_gq[storage_id].counter, 1) +
2111 1; // `atomicAdd` returns the old value
2112 }
2113 isLast = __syncthreads_or(
2114 numAddsSoFar == getNumParallelBlocksForQuery(p, query_start));
2115 assert(numAddsSoFar <= getNumParallelBlocksForQuery(p, query_start));
2116 }
2117 if (kNeedsAccumGradQ && !isLast) {
2118 gmem_tile.store(accum, thread_id);
2119 if (p.num_splits_key_device() > 1) {
2120 // Make sure everyone wrote before we release the lock
2121 __threadfence();
2122 __syncthreads();
2123 AtomicLock::release(&p.workspace_gq[storage_id].lock, thread_id);
2124 }
2125 } else {
2126 // NOTE: We're not releasing the lock because no one is expected
2127 // to come after us (we're the last one to write)
2128 typename MatmulGradQ::OutputTileIterator output_it(
2129 typename MatmulGradQ::OutputTileIterator::Params{p.gQ_strideM()},
2130 p.grad_query_ptr + query_start * p.gQ_strideM() + col,
2131 {problem_size.m(), problem_size.n()},
2132 thread_id);
2133 // if `direct_store` is true, we store to gmem (`*gmem = accum`)
2134 // otherwise, we accumulate in gmem (`*gmem = *gmem + accum`)
2135 // If we know ahead of time when we will write for the first time
2136 // we can:
2137 // (1) Avoid an additional memory read
2138 // (2) Avoid the cost of initializing memory to 0
2139 bool direct_store = kNeedsAccumGradQ || key_start == 0 ||
2140 (p.num_splits_key_device() > 1);
2141 accumulateInGmem<MatmulGradQ>(
2142 isLastColumn ? shared_storage.gradQ_epilogue_lastIter()
2143 : shared_storage.gradQ_epilogue(),
2144 accum,
2145 output_it,
2146 direct_store,
2147 warp_id,
2148 lane_id);
2149 }
2150 }
2151 /////////////////////////////////////////////////////////////////////////////////////////////////
2152 // GradK matmul
2153 //
2154 // grad_k[i_start:i_end] += tmp.transpose(-2, -1) @ q_i
2155 /////////////////////////////////////////////////////////////////////////////////////////////////
2156 rematerializeThreadIds();
2157
2158 constexpr bool kSingleIterationGradK =
2159 kMaxK <= MatmulGradK::ThreadblockShape::kN;
2160 for (int col = 0; col < (kSingleIterationGradK ? 1 : p.head_dim);
2161 col += MatmulGradK::ThreadblockShape::kN) {
2162 using Mma = typename MatmulGradK::Mma;
2163 using AccumTileGmem = typename MatmulGradQ::AccumTileGmem;
2164
2165 cutlass::gemm::GemmCoord problem_size(
2166 num_keys_in_block,
2167 p.head_dim - col,
2168 num_queries_in_block);
2169 auto createEpilogueIter = [&]() {
2170 return typename MatmulGradK::OutputTileIterator(
2171 typename MatmulGradK::OutputTileIterator::Params{p.gK_strideM()},
2172 p.grad_key_ptr + key_start * p.gK_strideM() + col,
2173 {num_keys_in_block,
2174 p.head_dim - col},
2175 thread_id);
2176 };
2177
2178 // q_i
2179 typename Mma::IteratorB iterator_B(
2180 {int32_t(p.q_strideM)},
2181 const_cast<scalar_t*>(p.query_ptr + query_start * p.q_strideM + col),
2182 {problem_size.k(), problem_size.n()},
2183 thread_id,
2184 no_offset);
2185
2186 auto getTmp = [&](int) { return &shared_storage.tmp_shared_storage(); };
2187 auto getTmpT = [&](int) { return &shared_storage.tmpT_shared_storage(); };
2188 // this is basically:
2189 // opA = kIsTransposedA ? getTmp() : getTmpT();
2190 bool constexpr kIsTransposedA =
2191 MatmulGradK::DefaultMmaFromSmem::kIsTransposedA;
2192 auto& opA = *call_conditional<
2193 kIsTransposedA,
2194 decltype(getTmp),
2195 decltype(getTmpT)>::apply(getTmp, getTmpT, 0);
2196 Mma mma(
2197 // operand A: dSij.T
2198 opA.accum_ref(),
2199 // operand B: Qi
2200 shared_storage.mm_gradK().operand_B_ref(),
2201 thread_id,
2202 warp_id,
2203 lane_id);
2204
2205 int storage_id = col / MatmulGradK::ThreadblockShape::kN;
2206 AccumTileGmem gmem_tile{
2207 p.workspace + storage_id * AccumTileGmem::kElementsStored};
2208 if (!kOutputInRF) {
2209 if (isFirstQuery || !kNeedsAccumGradK) {
2210 output_frags.gradK.clear();
2211 } else {
2212 gmem_tile.load(output_frags.gradK, thread_id);
2213 }
2214 }
2215 mma.set_prologue_done(kPrologueGK);
2216
2217 auto gemm_k_iterations =
2218 (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
2219
2220 // Compute threadblock-scoped matrix multiply-add
2221 __syncthreads();
2222
2223 mma(gemm_k_iterations,
2224 output_frags.gradK,
2225 iterator_B,
2226 output_frags.gradK);
2227 __syncthreads();
2228 bool isLastColumn = kSingleIterationGradK ||
2229 col + MatmulGradK::ThreadblockShape::kN >= p.head_dim;
2230 if (kPrologueGK && !isLastColumn) {
2231 prologueGradK(col + MatmulGradK::ThreadblockShape::kN);
2232 }
2233
2234 if (kPrologueQK && isLastColumn) {
2235 int32_t next_query, next_key;
2236 incrIteration(p, query_start, key_start, next_query, next_key);
2237 DISPATCH_BOOL(
2238 next_key != key_start, kForceReloadK, ([&]() {
2239 prologueQkNextIteration<kForceReloadK>(
2240 shared_storage, p, next_query, next_key, warp_id, lane_id);
2241 }));
2242 }
2243
2244 // Output results
2245 if (!kOutputInRF) {
2246 if (kNeedsAccumGradK && !isLastQuery) {
2247 gmem_tile.store(output_frags.gradK, thread_id);
2248 } else {
2249 accumulateInGmem<MatmulGradK>(
2250 isLastColumn ? shared_storage.gradK_epilogue_final()
2251 : shared_storage.gradK_epilogue(),
2252 output_frags.gradK,
2253 createEpilogueIter(),
2254 isFirstQuery || kNeedsAccumGradK,
2255 warp_id,
2256 lane_id);
2257 __syncthreads();
2258 }
2259 }
2260 }
2261 }
2262
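// When several key splits run in parallel (and no causal mask constrains
// the iteration order), each split starts its query sweep at a different,
// wrapped-around offset, so the splits do not all hit the same grad_query
// block, and its workspace lock, at the same time.
// E.g. (hypothetical numbers) with kBlockSizeI = 64 and 2 splits:
// split 0 starts at query 0 and split 1 at query 64, both wrapping around
// until every query block has been visited.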
2263 static CUTLASS_HOST_DEVICE int32_t getQueryStartShift(Params const& p) {
2264 if (p.custom_mask_type == NoCustomMask && p.num_splits_key_device() > 1) {
2265 return (p.split_key_device() * kBlockSizeI) % getQueryEnd(p);
2266 }
2267 return 0;
2268 }
2269
2270 // Iteration order logic
2271 static CUTLASS_HOST_DEVICE int32_t
2272 getQueryStart(Params const& p, int32_t key_start) {
2273 return getSmallestQueryForKey(p, key_start) + getQueryStartShift(p);
2274 };
2275 static CUTLASS_HOST_DEVICE int32_t getQueryEnd(Params const& p) {
2276 return align_up(p.num_queries, kBlockSizeI);
2277 };
2278
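// Smallest query block that attends to the key block starting at
// `key_start`. Example (top-left causal, no window): query q only attends
// keys k <= q, so this key block is first needed by query q = key_start,
// rounded down to a kBlockSizeI-aligned block. A bottom-right mask shifts
// the diagonal by num_keys - num_queries, and a sliding window bounds the
// largest query for which this key block is still inside the window.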
2279 static CUTLASS_HOST_DEVICE int32_t
2280 getSmallestQueryForKey(Params const& p, int32_t key_start) {
2281 if (p.custom_mask_type == NoCustomMask) {
2282 return 0;
2283 }
2284 int32_t shift = p.custom_mask_type == CausalFromBottomRight
2285 ? p.num_keys - p.num_queries
2286 : 0;
2287 int32_t window_size =
2288 p.window_size == 0 ? p.num_queries + p.num_keys : p.window_size;
2289
2290 auto last_key_for_block =
2291 cutlass::fast_min(key_start + kBlockSizeJ, p.num_keys) - 1;
2292 int first_query = key_start - shift;
2293 int last_query = last_key_for_block - shift + window_size - 1;
2294 if (last_query < 0 || first_query >= p.num_queries) {
2295 return getQueryEnd(p); // nothing to compute in this column
2296 }
2297 first_query = cutlass::fast_max(0, first_query);
2298 return (first_query / kBlockSizeI) * kBlockSizeI;
2299 };
2300
2301 // Returns how many kernel blocks will write to a given block in `grad_query`
2302 // This is usually equal to the number of key splits, but can be different
2303 // for instance in the causal case, or varying seqlen
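// E.g. (hypothetical numbers) top-left causal, kBlockSizeI = kBlockSizeJ
// = 64, query_start = 64: the block covers queries 64..127, which only
// attend keys 0..127, so at most 2 key blocks (and never more than the
// number of key splits) can contribute to this grad_query tile.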
2304 static CUTLASS_HOST_DEVICE int32_t
2305 getNumParallelBlocksForQuery(Params const& p, int32_t query_start) {
2306 int16_t num_key_blocks = ceil_div(p.num_keys, kBlockSizeJ);
2307 if (p.custom_mask_type != NoCustomMask) {
2308 int32_t shift = p.custom_mask_type == CausalFromBottomRight
2309 ? p.num_keys - p.num_queries
2310 : 0;
2311 int32_t last_query_for_block =
2312 cutlass::fast_min(query_start + kBlockSizeI, p.num_queries) - 1;
2313 int32_t last_key_for_block =
2314 cutlass::fast_min(last_query_for_block + shift, p.num_keys - 1);
2315 int32_t first_key_for_block = p.window_size == 0
2316 ? 0
2317 : cutlass::fast_max(query_start - p.window_size + 1 + shift, 0);
2318
2319 if (p.window_size == 0) {
2320 num_key_blocks = last_key_for_block / kBlockSizeJ + 1;
2321 } else {
2322 num_key_blocks = (last_key_for_block / kBlockSizeJ) -
2323 (first_key_for_block / kBlockSizeJ) + 1;
2324 }
2325
2326 if (last_key_for_block < 0 || first_key_for_block >= p.num_keys) {
2327 num_key_blocks = 0;
2328 }
2329 }
2330 return cutlass::fast_min(p.num_splits_key_device(), num_key_blocks);
2331 };
2332
2333 // Returns the next block to process
2334 static CUTLASS_HOST_DEVICE void incrIteration(
2335 Params const& p,
2336 int32_t query_start,
2337 int32_t key_start,
2338 int32_t& next_query,
2339 int32_t& next_key) {
2340 next_query = query_start + kBlockSizeI;
2341 next_key = key_start;
2342 auto query_shift = getQueryStartShift(p);
2343 // Wrap around
2344 if (query_shift) {
2345 if (next_query >= p.num_queries) {
2346 next_query = getSmallestQueryForKey(p, key_start);
2347 return;
2348 } else if (query_start < query_shift && query_shift <= next_query) {
2349 // jump to next key
2350 } else {
2351 return;
2352 }
2353 } else {
2354 if (p.window_size > 0) {
2355 int32_t shift = p.custom_mask_type == CausalFromBottomRight
2356 ? p.num_keys - p.num_queries
2357 : 0;
2358 // last key that is not masked out
2359 int last_key_for_block =
2360 cutlass::fast_min(key_start + kBlockSizeJ, p.num_keys) - 1;
2361 int last_query = last_key_for_block - shift + p.window_size - 1;
2362 if (next_query <= last_query && next_query < p.num_queries) {
2363 return;
2364 }
2365 } else if (next_query < p.num_queries) {
2366 return;
2367 }
2368 // jump to next key
2369 }
2370 // Next key
2371 next_key = key_start + p.num_splits_key_device() * kBlockSizeJ;
2372 next_query = getQueryStart(p, next_key);
2373 }
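// Taken together with getQueryStart, the traversal per thread block looks
// roughly like this (sketch, not the actual driver loop):
//
//   for (key = first key block of this split; key < num_keys;
//        key += num_splits_key * kBlockSizeJ)
//     for (query = getQueryStart(p, key); more query blocks for this key;
//          query += kBlockSizeI)   // possibly wrapping around
//       processBlockIJ(query, key);
//
// gradK / gradV accumulate over the inner (query) loop, while gradQ is
// accumulated across thread blocks through the gmem workspace.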
2374
2375 template <bool kForceReloadK>
2376 static CUTLASS_DEVICE void prologueQkNextIteration(
2377 SharedStorage& shared_storage,
2378 Params const& p,
2379 int32_t query_start,
2380 int32_t key_start,
2381 uint8_t warp_id,
2382 uint8_t lane_id) {
2383 if (query_start >= p.num_queries || key_start >= p.num_keys) {
2384 return;
2385 }
2386
2387 static constexpr bool kReloadK =
2388 kForceReloadK || !MatmulQK::Mma::kSmemContainsEntireMat;
2389 int thread_id = 32 * warp_id + lane_id;
2390 typename MatmulQK::Mma::IteratorA iterator_A(
2391 {int32_t(p.k_strideM)},
2392 const_cast<scalar_t*>(p.key_ptr + key_start * p.k_strideM),
2393 {p.num_keys - key_start, p.head_dim},
2394 thread_id,
2395 cutlass::MatrixCoord{0, 0});
2396
2397 typename MatmulQK::Mma::IteratorB iterator_B(
2398 {int32_t(p.q_strideM)},
2399 const_cast<scalar_t*>(p.query_ptr + query_start * p.q_strideM),
2400 {p.head_dim, p.num_queries - query_start},
2401 thread_id,
2402 cutlass::MatrixCoord{0, 0});
2403
2404 MatmulQK::Mma::prologue<kReloadK, true>(
2405 shared_storage.mm_qk_k(),
2406 shared_storage.mm_qk_q(),
2407 iterator_A,
2408 iterator_B,
2409 thread_id,
2410 p.head_dim);
2411 }
2412
2413 template <bool skipBoundsChecks>
2414 static CUTLASS_DEVICE void writeFragsToGmem(
2415 SharedStorage& shared_storage,
2416 OutputFragments& output_frags,
2417 Params const& p,
2418 int32_t key_start,
2419 uint8_t warp_id,
2420 uint8_t lane_id) {
2421 uint16_t thread_id = 32 * warp_id + lane_id;
2422 int32_t num_keys_in_block = skipBoundsChecks
2423 ? MatmulQK::Mma::Shape::kM
2424 : cutlass::fast_min(
2425 (int32_t)MatmulQK::Mma::Shape::kM, p.num_keys - key_start);
2426 typename MatmulGradV::OutputTileIterator outputV_it(
2427 typename MatmulGradV::OutputTileIterator::Params{p.gV_strideM()},
2428 p.grad_value_ptr + key_start * p.gV_strideM(),
2429 {num_keys_in_block, p.head_dim_value},
2430 thread_id);
2431 accumulateInGmem<MatmulGradV>(
2432 shared_storage.gradV_epilogue_final(),
2433 output_frags.gradV,
2434 outputV_it,
2435 true,
2436 warp_id,
2437 lane_id);
2438
2439 typename MatmulGradK::OutputTileIterator outputK_it(
2440 typename MatmulGradK::OutputTileIterator::Params{p.gK_strideM()},
2441 p.grad_key_ptr + key_start * p.gK_strideM(),
2442 {num_keys_in_block,
2443 p.head_dim},
2444 thread_id);
2445 accumulateInGmem<MatmulGradK>(
2446 shared_storage.gradK_epilogue_final(),
2447 output_frags.gradK,
2448 outputK_it,
2449 true,
2450 warp_id,
2451 lane_id);
2452 }
2453
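// `accumulateInGmem` either overwrites the output tile (first == true:
// ScaleType::Nothing, so D = accum) or adds into it (first == false:
// ScaleType::NoBetaScaling with alpha == 1, so D = source + accum). This
// read-modify-write is how partial dK / dV tiles from different
// iterations end up summed in global memory.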
2454 template <typename MatmulT>
2455 static CUTLASS_DEVICE void accumulateInGmem(
2456 typename MatmulT::DefaultEpilogue::SharedStorage& epilogue_smem,
2457 typename MatmulT::Mma::FragmentC const& accum,
2458 typename MatmulT::OutputTileIterator output_it,
2459 bool first,
2460 uint8_t warp_id,
2461 uint8_t lane_id) {
2462 using DefaultEpilogue = typename MatmulT::DefaultEpilogue;
2463 using DefaultOutputOp = typename MatmulT::DefaultOutputOp;
2464 using Mma = typename MatmulT::Mma;
2465 int thread_id = 32 * warp_id + lane_id;
2466 DISPATCH_BOOL(
2467 first, kIsFirst, ([&]() {
2468 static constexpr auto ScaleType = kIsFirst
2469 ? cutlass::epilogue::thread::ScaleType::Nothing
2470 : cutlass::epilogue::thread::ScaleType::NoBetaScaling;
2471 using EpilogueOutputOp =
2472 typename cutlass::epilogue::thread::LinearCombination<
2473 typename DefaultOutputOp::ElementOutput,
2474 DefaultOutputOp::kCount,
2475 typename DefaultOutputOp::ElementAccumulator,
2476 typename DefaultOutputOp::ElementCompute,
2477 ScaleType>;
2478 using Epilogue =
2479 typename cutlass::epilogue::threadblock::EpiloguePipelined<
2480 typename DefaultEpilogue::Shape,
2481 typename Mma::Operator,
2482 DefaultEpilogue::kPartitionsK,
2483 typename MatmulT::OutputTileIterator,
2484 typename DefaultEpilogue::AccumulatorFragmentIterator,
2485 typename DefaultEpilogue::WarpTileIterator,
2486 typename DefaultEpilogue::SharedLoadIterator,
2487 EpilogueOutputOp,
2488 typename DefaultEpilogue::Padding,
2489 DefaultEpilogue::kFragmentsPerIteration,
2490 true // IterationsUnroll
2491 >;
2492 EpilogueOutputOp rescale({1, 1});
2493 Epilogue epilogue(epilogue_smem, thread_id, warp_id, lane_id);
2494 epilogue(rescale, output_it, accum, output_it);
2495 }));
2496 }
2497
2498 template <int kElementsPerAccess>
2499 static CUTLASS_DEVICE void computeDelta(
2500 Params const& p,
2501 int32_t query_start,
2502 uint8_t warp_id,
2503 uint8_t lane_id) {
2504 // Each thread computes one value for Delta
2505 // Depending on warp configuration, we might have multiple
2506 // threads of the same warp working on the same row
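// Per query row i handled by this thread, this computes
//   delta[i] = sum_k grad_output[i, k] * output[i, k]
// i.e. the Di term consumed above when forming dSij = Pij * (dPij - Di).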
2507 using AccessType = cutlass::Array<scalar_t, kElementsPerAccess>;
2508 static_assert(kNumThreads >= kBlockSizeI, "");
2509 static constexpr int kNumThreadsPerLine = kNumThreads / kBlockSizeI;
2510 int16_t thread_id = 32 * warp_id + lane_id;
2511
2512 int16_t laneFirstCol = kElementsPerAccess * (lane_id % kNumThreadsPerLine);
2513 int16_t laneRow = thread_id / kNumThreadsPerLine;
2514 bool rowPred = (query_start + laneRow) < p.num_queries;
2515 bool pred = rowPred;
2516
2517 // On Windows, the previous syntax `__restrict__ AccessType*`
2518 // resulted in the error: "restrict" is not allowed
2519 const AccessType* __restrict__ grad_output_ptr =
2520 reinterpret_cast<const AccessType*>(
2521 p.grad_output_ptr + (query_start + laneRow) * p.gO_strideM +
2522 laneFirstCol);
2523 const AccessType* __restrict__ output_ptr =
2524 reinterpret_cast<const AccessType*>(
2525 p.output_ptr + (query_start + laneRow) * p.o_strideM() +
2526 laneFirstCol);
2527
2528 static constexpr int64_t kMaxIters =
2529 kMaxK / (kElementsPerAccess * kNumThreadsPerLine);
2530 constexpr int kPipelineStages = 2;
2531 accum_t delta_value = accum_t(0);
2532 using GlobalLoad =
2533 cutlass::arch::global_load<AccessType, sizeof(AccessType)>;
2534 AccessType frag_grad_output[kPipelineStages];
2535 AccessType frag_output[kPipelineStages];
2536
2537 auto loadAndIncrement = [&](int ld_pos, bool is_valid) {
2538 frag_grad_output[ld_pos].clear();
2539 frag_output[ld_pos].clear();
2540 GlobalLoad(frag_grad_output[ld_pos], grad_output_ptr, is_valid);
2541 GlobalLoad(frag_output[ld_pos], output_ptr, is_valid);
2542 grad_output_ptr += kNumThreadsPerLine;
2543 output_ptr += kNumThreadsPerLine;
2544 };
2545
2546 CUTLASS_PRAGMA_UNROLL
2547 for (int iter = 0; iter < kPipelineStages - 1; ++iter) {
2548 int ld_pos = iter % kPipelineStages;
2549 pred = pred &&
2550 (laneFirstCol + iter * kElementsPerAccess * kNumThreadsPerLine) <
2551 p.head_dim_value;
2552 loadAndIncrement(ld_pos, pred);
2553 }
2554 auto columnIteration = [&](int iter) {
2555 // Load for next iter
2556 int ld_pos = (iter + kPipelineStages - 1) % kPipelineStages;
2557 pred = pred &&
2558 (laneFirstCol +
2559 (iter + kPipelineStages - 1) * kElementsPerAccess *
2560 kNumThreadsPerLine) < p.head_dim_value;
2561 loadAndIncrement(ld_pos, pred);
2562 CUTLASS_PRAGMA_UNROLL
2563 for (int i = 0; i < AccessType::kElements; ++i) {
2564 delta_value += accum_t(frag_output[iter % kPipelineStages][i]) *
2565 accum_t(frag_grad_output[iter % kPipelineStages][i]);
2566 }
2567 };
2568
2569 // If we have a small lower-bound for K, we can unroll the loop
2570 if (kMaxK <= 256) {
2571 CUTLASS_PRAGMA_UNROLL
2572 for (int iter = 0; iter < kMaxIters; ++iter) {
2573 columnIteration(iter);
2574 }
2575 } else {
2576 int num_iters =
2577 ceil_div(p.head_dim_value, kElementsPerAccess * kNumThreadsPerLine) *
2578 (kElementsPerAccess * kNumThreadsPerLine);
2579 for (int iter = 0; iter < num_iters; ++iter) {
2580 columnIteration(iter);
2581 }
2582 }
2583
2584 // Reduce between workers
2585 static_assert(
2586 kNumThreadsPerLine == 1 || kNumThreadsPerLine == 2 ||
2587 kNumThreadsPerLine == 4,
2588 "");
2589 CUTLASS_PRAGMA_UNROLL
2590 for (int i = 1; i < kNumThreadsPerLine; i *= 2) {
2591 delta_value = delta_value + __shfl_xor_sync(0xffffffff, delta_value, i);
2592 }
2593
2594 // Store in gmem
2595 if (rowPred) {
2596 p.delta_ptr[query_start + laneRow] = delta_value;
2597 }
2598 }
2599 };
2600
2601 template <typename AK>
2602 __global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
2603 attention_kernel_backward_batched_impl(typename AK::Params p) {
2604 if (!p.advance_to_block()) {
2605 return;
2606 }
2607 AK::attention_kernel(p);
2608 }
2609
2610 template <typename AK>
2611 __global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
2612 attention_kernel_backward_batched(typename AK::Params params);
2613
2614 } // namespace PyTorchMemEffAttention
2615