/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

namespace fastertransformer {

// Tile configurations selectable for a CUTLASS GEMM.
//
// Every shape in an enumerator name is written as MxNxK. For weight-only
// quantization, the K extent of the runtime config MUST be the same as the
// K extent used in the kernel layout details.
enum class CutlassTileConfig {
    // No tile configuration has been selected.
    // NOTE(review): the previous comment on this enumerator was a copy of the
    // one on ChooseWithHeuristic; confirm how callers treat Undefined.
    Undefined,

    // Signals that heuristics should be run to choose a config.
    ChooseWithHeuristic,

    // SIMT config.
    CtaShape128x128x8_WarpShape64x64x8,

    // TensorCore configs with CTA_N = 128 and CTA_K = 64.
    // Warp config for M = 32.
    CtaShape32x128x64_WarpShape32x32x64,

    // Warp configs for M = 64.
    CtaShape64x128x64_WarpShape32x64x64,
    CtaShape64x128x64_WarpShape64x32x64,

    // Warp configs for M = 128.
    CtaShape128x128x64_WarpShape64x32x64,
    CtaShape128x128x64_WarpShape128x32x64
};

// Split-K strategies that a GEMM config can request.
enum class SplitKStyle {
    NO_SPLIT_K,
    SPLIT_K_SERIAL,
    // SPLIT_K_PARALLEL // Not supported yet
};

// Bundle of the settings that describe one GEMM configuration.
// A default-constructed instance requests heuristic tile selection with no
// split-K.
struct CutlassGemmConfig {
    // Tile shape to use; defaults to letting heuristics decide.
    CutlassTileConfig tile_config{CutlassTileConfig::ChooseWithHeuristic};

    // Split-K strategy; defaults to no split-K.
    SplitKStyle split_k_style{SplitKStyle::NO_SPLIT_K};

    // Defaults to -1 -- presumably a "not set" sentinel; confirm against the
    // code that consumes this struct.
    int split_k_factor{-1};

    // Defaults to -1 -- presumably a "not set" sentinel; confirm against the
    // code that consumes this struct.
    int stages{-1};
};

}  // namespace fastertransformer