1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 #pragma once 9 10 #include <ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma_multistage.h> 11 #include <ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma_pipelined.h> 12 13 #include <cutlass/gemm/threadblock/mma_multistage.h> 14 #include <cutlass/gemm/threadblock/mma_pipelined.h> 15 template <typename Mma, int kMaxK> 16 struct MakeCustomMma; 17 18 template < 19 typename Shape, 20 typename IteratorA, 21 typename SmemIteratorA, 22 cutlass::arch::CacheOperation::Kind CacheOpA, 23 typename IteratorB, 24 typename SmemIteratorB, 25 cutlass::arch::CacheOperation::Kind CacheOpB, 26 typename ElementC, 27 typename LayoutC, 28 typename Policy, 29 int Stages, 30 cutlass::gemm::SharedMemoryClearOption SharedMemoryClear, 31 int kMaxK> 32 struct MakeCustomMma< 33 cutlass::gemm::threadblock::MmaMultistage< 34 Shape, 35 IteratorA, 36 SmemIteratorA, 37 CacheOpA, 38 IteratorB, 39 SmemIteratorB, 40 CacheOpB, 41 ElementC, 42 LayoutC, 43 Policy, 44 Stages, 45 SharedMemoryClear>, 46 kMaxK> { 47 // Reduce the number of stages if we don't need that many 48 static int constexpr kStages = 49 kMaxK == cutlass::platform::numeric_limits<int>::max() 50 ? Stages 51 : cutlass::const_min( 52 Stages, 53 (kMaxK + int(Shape::kK) - 1) / int(Shape::kK)); 54 using Mma = cutlass::gemm::threadblock::CustomMmaMultistage< 55 Shape, 56 IteratorA, 57 SmemIteratorA, 58 CacheOpA, 59 IteratorB, 60 SmemIteratorB, 61 CacheOpB, 62 ElementC, 63 LayoutC, 64 Policy, 65 kStages, 66 SharedMemoryClear, 67 kMaxK>; 68 }; 69 70 template < 71 typename Shape, 72 typename IteratorA, 73 typename SmemIteratorA, 74 typename IteratorB, 75 typename SmemIteratorB, 76 typename ElementC, 77 typename LayoutC, 78 typename Policy, 79 int kMaxK> 80 struct MakeCustomMma< 81 cutlass::gemm::threadblock::MmaPipelined< 82 Shape, 83 IteratorA, 84 SmemIteratorA, 85 IteratorB, 86 SmemIteratorB, 87 ElementC, 88 LayoutC, 89 Policy>, 90 kMaxK> { 91 using Mma = cutlass::gemm::threadblock::CustomMmaPipelined< 92 Shape, 93 IteratorA, 94 SmemIteratorA, 95 IteratorB, 96 SmemIteratorB, 97 ElementC, 98 LayoutC, 99 Policy>; 100 }; 101