Home
last modified time | relevance | path

Searched defs:MatmulGradV (Results 1 – 1 of 1) sorted by relevance

/aosp_15_r20/external/pytorch/aten/src/ATen/native/transformers/cuda/mem_eff_attention/
H A Dkernel_backward.h354 struct MatmulGradV { struct
361 using ThreadblockShape =
363 using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
364 using InstructionShape = typename GemmType::InstructionShape;
366 using DefaultGemm = cutlass::gemm::kernel::DefaultGemm<
397 using WarpIteratorA = typename cutlass::gemm::threadblock::
406 using DefaultMmaFromSmem =
413 using Mma = typename DefaultMmaFromSmem::Mma;
414 using IteratorB = typename Mma::IteratorB;
415 using WarpCount = typename Mma::WarpCount;
[all …]