1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 #pragma once 9 10 #include <ATen/native/transformers/cuda/mem_eff_attention/iterators/warp_iterator_from_smem.h> 11 12 template <typename WarpIterator> 13 struct TransposeWarpIterator { 14 using Iterator = char; 15 static bool constexpr kSupportsTranspose = false; 16 }; 17 18 template < 19 /// Operand identity 20 cutlass::gemm::Operand Operand, 21 /// Data type of A elements 22 typename Element, 23 typename InstructionShape, 24 bool kTranspose> 25 struct TransposeWarpIterator< 26 cutlass::gemm::warp:: 27 WarpIteratorFromSmem<Operand, Element, InstructionShape, kTranspose>> { 28 using Iterator = cutlass::gemm::warp:: 29 WarpIteratorFromSmem<Operand, Element, InstructionShape, !kTranspose>; 30 static bool constexpr kSupportsTranspose = true; 31 }; 32