1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 #pragma once
9 
10 #include <ATen/native/transformers/cuda/mem_eff_attention/iterators/warp_iterator_from_smem.h>
11 
/// Trait that maps a warp iterator type to its transposed counterpart.
///
/// This primary template is selected for any type that has no dedicated
/// specialization. It reports that transposition is not supported, and its
/// `Iterator` member is just `char`, i.e. a placeholder rather than a usable
/// iterator type.
template <typename WarpIteratorT>
struct TransposeWarpIterator {
  /// No transposed iterator is available for `WarpIteratorT`.
  static constexpr bool kSupportsTranspose = false;
  /// Placeholder type; check `kSupportsTranspose` before relying on it.
  using Iterator = char;
};
17 
18 template <
19     /// Operand identity
20     cutlass::gemm::Operand Operand,
21     /// Data type of A elements
22     typename Element,
23     typename InstructionShape,
24     bool kTranspose>
25 struct TransposeWarpIterator<
26     cutlass::gemm::warp::
27         WarpIteratorFromSmem<Operand, Element, InstructionShape, kTranspose>> {
28   using Iterator = cutlass::gemm::warp::
29       WarpIteratorFromSmem<Operand, Element, InstructionShape, !kTranspose>;
30   static bool constexpr kSupportsTranspose = true;
31 };
32