/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Default warp-level GEMM operators selected by data type, size, and layouts of operands.
*/

#pragma once

#include <cutlass/cutlass.h>
#include <cutlass/gemm/warp/default_mma_tensor_op.h>
#include <cutlass/gemm/warp/mma_tensor_op.h>

#include <ATen/native/cuda/cutlass_extensions/arch/mma.h>
#include <ATen/native/cuda/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h>

namespace cutlass {
namespace gemm {
namespace warp {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for m-by-n-by-kgroup
template<
    /// Shape of the warp-level matrix product operation (concept: GemmShape)
    typename WarpShape_,
    /// Shape of one matrix product instruction (concept: GemmShape)
    typename InstructionShape_,
    /// Data type of A elements
    typename ElementA,
    /// Layout of A matrix (concept: MatrixLayout)
    typename LayoutA,
    /// Data type of B elements
    typename ElementB,
    /// Layout of B matrix (concept: MatrixLayout)
    typename LayoutB,
    /// Element type of C matrix
    typename ElementC,
    /// Layout of C matrix (concept: MatrixLayout)
    typename LayoutC,
    /// Number of partitions along K dimension
    int PartitionsK,
    /// Store the accumulators in row major or column major. Row major is used
    /// when output layout is interleaved.
    bool AccumulatorsInRowMajor>
struct DefaultMmaTensorOp<WarpShape_,
                          InstructionShape_,
                          ElementA,
                          LayoutA,
                          ElementB,
                          LayoutB,
                          ElementC,
                          LayoutC,
                          arch::OpMultiplyAddDequantizeInterleavedBToA,
                          PartitionsK,
                          AccumulatorsInRowMajor> {

private:
    // Instruction shape for the FP16 compute: the narrow B operand is dequantized to
    // ElementA's type before the tensor core instruction executes.
    using ComputeInstructionShape = InstructionShape_;

    // K extent of the instruction used to load the narrow B operand. Chosen so we get
    // K=16 for int8 and K=32 for int4 when ElementA is 16 bits wide
    // (8 * 16 / 8 = 16 and 8 * 16 / 4 = 32, respectively).
    static constexpr int LoadInstructionK = 8 * sizeof_bits<ElementA>::value / sizeof_bits<ElementB>::value;

    // Shape for loading the narrow data type from shared memory
    using LoadInstructionShape = GemmShape<InstructionShape_::kM, InstructionShape_::kN, LoadInstructionK>;

public:
    // Policy wrapping the arch-level MMA. Both operands use ElementA's type because the
    // B fragment is upconverted to ElementA before the multiply-accumulate.
    using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<cutlass::arch::Mma<InstructionShape_,
                                                                             32,
                                                                             ElementA,
                                                                             cutlass::layout::RowMajor,
                                                                             ElementA,
                                                                             cutlass::layout::ColumnMajor,
                                                                             ElementC,
                                                                             cutlass::layout::RowMajor,
                                                                             arch::OpMultiplyAdd>,
                                                          cutlass::MatrixShape<1, 1>>;

    // Define the warp-level tensor op
    using Type = cutlass::gemm::warp::MmaTensorOpComputeBWithF16<WarpShape_,
                                                                 ElementA,
                                                                 LayoutA,
                                                                 ElementB,
                                                                 LayoutB,
                                                                 ElementC,
                                                                 LayoutC,
                                                                 Policy,
                                                                 LoadInstructionShape,
                                                                 PartitionsK,
                                                                 AccumulatorsInRowMajor>;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace warp
}  // namespace gemm
}  // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
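
// Usage sketch (illustrative, not part of this header): how the specialization above might be
// instantiated for an fp16-activation, int4-weight warp-level GEMM. The shapes and layouts
// below are assumptions chosen for the example; any warp/instruction shape supported by the
// target tensor cores works. With ElementA = half_t (16 bits) and ElementB = uint4b_t
// (4 bits), LoadInstructionK = 8 * 16 / 4 = 32, so B is fetched with shape
// GemmShape<16, 8, 32> while the MMA itself runs at GemmShape<16, 8, 16>.
//
//   using WarpMma = cutlass::gemm::warp::DefaultMmaTensorOp<
//       cutlass::gemm::GemmShape<32, 32, 64>,             // WarpShape
//       cutlass::gemm::GemmShape<16, 8, 16>,              // InstructionShape
//       cutlass::half_t, cutlass::layout::RowMajor,       // A: fp16 activations
//       cutlass::uint4b_t, cutlass::layout::ColumnMajor,  // B: int4 weights, dequantized to fp16
//       cutlass::half_t, cutlass::layout::RowMajor,       // C
//       cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA,
//       /*PartitionsK=*/1,
//       /*AccumulatorsInRowMajor=*/false>::Type;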