1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 #pragma once 9 10 #include <ATen/native/transformers/cuda/mem_eff_attention/iterators/predicated_tile_access_iterator_residual_last.h> 11 #include <ATen/native/transformers/cuda/mem_eff_attention/iterators/predicated_tile_iterator_residual_last.h> 12 13 14 namespace cutlass { 15 namespace transform { 16 namespace threadblock { 17 18 template <typename BaseIterator> 19 struct MakeIteratorResidualLast; 20 21 template < 22 typename Shape, 23 typename Element, 24 typename Layout, 25 int AdvanceRank, 26 typename ThreadMap, 27 int AccessSize, 28 bool Gather> 29 struct MakeIteratorResidualLast<PredicatedTileIterator< 30 Shape, 31 Element, 32 Layout, 33 AdvanceRank, 34 ThreadMap, 35 AccessSize, 36 Gather>> { 37 using Iterator = PredicatedTileIteratorResidualLast< 38 Shape, 39 Element, 40 Layout, 41 AdvanceRank, 42 ThreadMap, 43 AccessSize, 44 Gather>; 45 }; 46 47 template < 48 typename Shape, 49 typename Element, 50 typename Layout, 51 int AdvanceRank, 52 typename ThreadMap, 53 typename AccessType, 54 bool Gather> 55 struct MakeIteratorResidualLast<PredicatedTileAccessIterator< 56 Shape, 57 Element, 58 Layout, 59 AdvanceRank, 60 ThreadMap, 61 AccessType, 62 Gather>> { 63 using Iterator = PredicatedTileAccessIteratorResidualLast< 64 Shape, 65 Element, 66 Layout, 67 AdvanceRank, 68 ThreadMap, 69 AccessType, 70 Gather>; 71 }; 72 } // namespace threadblock 73 } // namespace transform 74 } // namespace cutlass 75