xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/cutlass_extensions/ft_gemm_configs.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 /*
2  * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
19 namespace fastertransformer {
// Tile configurations selectable for a CUTLASS GEMM.
//
// Note: shapes are written in MxNxK order. When doing weight-only quantization, the K shape of the
//       runtime config MUST match the K shape in the kernel layout details.
enum class CutlassTileConfig {
    /// Signals that we should run heuristics to choose a config.
    Undefined,

    /// Signals that we should run heuristics to choose a config.
    ChooseWithHeuristic,

    // ---- SIMT config ----
    CtaShape128x128x8_WarpShape64x64x8,

    // ---- TensorCore configs: CTA_N = 128, CTA_K = 64 ----

    // Warp configs for M = 32
    CtaShape32x128x64_WarpShape32x32x64,

    // Warp configs for M = 64
    CtaShape64x128x64_WarpShape32x64x64,
    CtaShape64x128x64_WarpShape64x32x64,

    // Warp configs for M = 128
    CtaShape128x128x64_WarpShape64x32x64,
    CtaShape128x128x64_WarpShape128x32x64,
};
44 
/// Split-K style recorded in a GEMM config.
enum class SplitKStyle {
    NO_SPLIT_K,
    SPLIT_K_SERIAL
    // SPLIT_K_PARALLEL is intentionally absent: not supported yet.
};
50 
51 struct CutlassGemmConfig {
52     CutlassTileConfig tile_config    = CutlassTileConfig::ChooseWithHeuristic;
53     SplitKStyle       split_k_style  = SplitKStyle::NO_SPLIT_K;
54     int               split_k_factor = -1;
55     int               stages         = -1;
56 };
57 
58 }  // namespace fastertransformer