Home
last modified time | relevance | path

Searched defs:MatmulGradQ (Results 1 – 1 of 1) sorted by relevance

/aosp_15_r20/external/pytorch/aten/src/ATen/native/transformers/cuda/mem_eff_attention/
H A Dkernel_backward.h494 struct MatmulGradQ { struct
496 using ThreadblockShape =
498 using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
499 using InstructionShape = typename GemmType::InstructionShape;
501 using DefaultGemm = cutlass::gemm::kernel::DefaultGemm<
522 using WarpIteratorA = typename cutlass::gemm::threadblock::
528 using DefaultMmaFromSmem =
534 using Mma = typename DefaultMmaFromSmem::Mma;
535 using IteratorB = typename Mma::IteratorB;
536 using WarpCount = typename Mma::WarpCount;
[all …]