1 // Original TunableOp is from onnxruntime.
2 // https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
3 // https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
4 // Copyright (c) Microsoft Corporation.
5 // Licensed under the MIT license.
6 //
7 // Adapting TunableOp into PyTorch
8 // Copyright (c) Advanced Micro Devices, Inc.
9 //
10 #pragma once
11
12 #include <ATen/cuda/tunable/GemmCommon.h>
13 #ifdef USE_ROCM
14 #include <ATen/cuda/tunable/GemmHipblaslt.h>
15 #include <ATen/cuda/tunable/GemmRocblas.h>
16 #endif
17 #include <ATen/cuda/tunable/StreamTimer.h>
18 #include <ATen/cuda/tunable/TunableOp.h>
19 #include <c10/cuda/CUDACachingAllocator.h>
20 #include <c10/util/Float8_e4m3fn.h>
21 #include <c10/util/Float8_e4m3fnuz.h>
22 #include <c10/util/Float8_e5m2.h>
23 #include <c10/util/Float8_e5m2fnuz.h>
24 #include <c10/util/StringUtil.h>
25
26 namespace at::cuda::tunable {
27
28 template <typename T>
29 class DefaultGemmOp : public Callable<GemmParams<T>> {
30 public:
Call(const GemmParams<T> * params)31 TuningStatus Call(const GemmParams<T>* params) override {
32 at::cuda::blas::gemm_internal<T>(
33 params->transa, params->transb,
34 params->m, params->n, params->k,
35 params->alpha,
36 params->a, params->lda,
37 params->b, params->ldb,
38 params->beta,
39 params->c, params->ldc);
40 return OK;
41 }
42 };
43
// Maps a BLAS transpose character to a bool: 't'/'T' -> true (transposed),
// anything else (e.g. 'n'/'N') -> false.
static bool _transposeBoolFromChar(char op) {
  switch (op) {
    case 't':
    case 'T':
      return true;
    default:
      return false;
  }
}
47
48 template <typename T>
49 class DefaultGemmAndBiasOp : public Callable<GemmAndBiasParams<T>> {
50 public:
Call(const GemmAndBiasParams<T> * params)51 TuningStatus Call(const GemmAndBiasParams<T>* params) override {
52 at::cuda::blas::gemm_and_bias<T>(
53 _transposeBoolFromChar(params->transa),
54 _transposeBoolFromChar(params->transb),
55 params->m, params->n, params->k,
56 params->alpha,
57 params->a, params->lda,
58 params->b, params->ldb,
59 params->bias,
60 params->c, params->ldc,
61 params->activation);
62 return OK;
63 }
64 };
65
66 template <typename T>
67 class DefaultGemmStridedBatchedOp : public Callable<GemmStridedBatchedParams<T>> {
68 public:
Call(const GemmStridedBatchedParams<T> * params)69 TuningStatus Call(const GemmStridedBatchedParams<T>* params) override {
70 at::cuda::blas::bgemm_internal<T>(
71 params->transa, params->transb,
72 params->m, params->n, params->k,
73 params->alpha,
74 params->a, params->lda, params->stride_a,
75 params->b, params->ldb, params->stride_b,
76 params->beta,
77 params->c, params->ldc, params->stride_c,
78 params->batch);
79 return OK;
80 }
81 };
82
83 template <typename T>
84 class DefaultScaledGemmOp : public Callable<ScaledGemmParams<T>> {
85 public:
Call(const ScaledGemmParams<T> * params)86 TuningStatus Call(const ScaledGemmParams<T>* params) override {
87 at::cuda::blas::scaled_gemm(
88 params->transa,
89 params->transb,
90 params->m,
91 params->n,
92 params->k,
93 params->a,
94 params->a_scale_ptr,
95 params->lda,
96 params->a_dtype,
97 params->b,
98 params->b_scale_ptr,
99 params->ldb,
100 params->b_dtype,
101 params->bias_ptr,
102 params->bias_dtype,
103 params->c,
104 params->c_scale_ptr,
105 params->ldc,
106 params->c_dtype,
107 params->amax_ptr,
108 params->use_fast_accum);
109 return OK;
110 }
111 };
112
113 template <typename T>
IsZero(T v)114 inline bool IsZero(T v) {
115 return v == 0.0f;
116 }
117
118 template <>
IsZero(BFloat16 v)119 inline bool IsZero(BFloat16 v) {
120 return v.x == 0;
121 }
122
123 template <>
IsZero(Half v)124 inline bool IsZero(Half v) {
125 return float(v) == 0.0f;
126 }
127
128 template <>
IsZero(c10::complex<double> v)129 inline bool IsZero(c10::complex<double> v) {
130 return v == 0.0;
131 }
132
133 template <>
IsZero(c10::complex<float> v)134 inline bool IsZero(c10::complex<float> v) {
135 return v == 0.0f;
136 }
137
138 template <typename T>
TypeName(T v)139 inline std::string TypeName(T v) {
140 return "unknown";
141 }
142
143 template <>
TypeName(float v)144 inline std::string TypeName(float v) {
145 return "float";
146 }
147
148 template <>
TypeName(double v)149 inline std::string TypeName(double v) {
150 return "double";
151 }
152
153 template <>
TypeName(BFloat16 v)154 inline std::string TypeName(BFloat16 v) {
155 return "BFloat16";
156 }
157
158 template <>
TypeName(Half v)159 inline std::string TypeName(Half v) {
160 return "Half";
161 }
162
163 template <>
TypeName(Float8_e4m3fn v)164 inline std::string TypeName(Float8_e4m3fn v) {
165 return "Float8_e4m3fn";
166 }
167
168 template <>
TypeName(Float8_e5m2 v)169 inline std::string TypeName(Float8_e5m2 v) {
170 return "Float8_e5m2";
171 }
172
173 template <>
TypeName(Float8_e4m3fnuz v)174 inline std::string TypeName(Float8_e4m3fnuz v) {
175 return "Float8_e4m3fnuz";
176 }
177
178 template <>
TypeName(Float8_e5m2fnuz v)179 inline std::string TypeName(Float8_e5m2fnuz v) {
180 return "Float8_e5m2fnuz";
181 }
182
183 template <>
TypeName(c10::complex<double> v)184 inline std::string TypeName(c10::complex<double> v) {
185 return "c10::complex<double>";
186 }
187
188 template <>
TypeName(c10::complex<float> v)189 inline std::string TypeName(c10::complex<float> v) {
190 return "c10::complex<float>";
191 }
192
193 template <typename T, BlasOp ALayout, BlasOp BLayout>
194 class GemmTunableOp : public TunableOp<GemmParams<T>, StreamTimer> {
195 public:
GemmTunableOp()196 GemmTunableOp() {
197 this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmOp<T>>());
198
199 #ifdef USE_ROCM
200 static const char *env_rocblas = std::getenv("PYTORCH_TUNABLEOP_ROCBLAS_ENABLED");
201 if (env_rocblas == nullptr || strcmp(env_rocblas, "1") == 0) {
202 for (auto&& [name, op] : GetRocBlasGemmTypeStringAndOps<T>()) {
203 this->RegisterOp(std::move(name), std::move(op));
204 }
205 }
206
207 static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
208 if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) {
209 // disallow tuning of hipblaslt with c10::complex
210 if constexpr (
211 !std::is_same_v<T, c10::complex<float>> &&
212 !std::is_same_v<T, c10::complex<double>>) {
213 for (auto&& [name, op] : GetHipBlasLtGemmTypeStringAndOps<T, ALayout, BLayout>()) {
214 this->RegisterOp(std::move(name), std::move(op));
215 }
216 }
217 }
218 #endif
219 }
220
Signature()221 std::string Signature() override {
222 return c10::str("GemmTunableOp_", TypeName<T>(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout));
223 }
224 };
225
226 template <typename T, BlasOp ALayout, BlasOp BLayout>
227 class GemmAndBiasTunableOp : public TunableOp<GemmAndBiasParams<T>, StreamTimer> {
228 public:
GemmAndBiasTunableOp()229 GemmAndBiasTunableOp() {
230 this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmAndBiasOp<T>>());
231
232 #ifdef USE_ROCM
233 static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
234 if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) {
235 // disallow tuning of hipblaslt with c10::complex
236 if constexpr (
237 !std::is_same_v<T, c10::complex<float>> &&
238 !std::is_same_v<T, c10::complex<double>>) {
239 for (auto&& [name, op] : GetHipBlasLtGemmAndBiasTypeStringAndOps<T, ALayout, BLayout>()) {
240 this->RegisterOp(std::move(name), std::move(op));
241 }
242 }
243 }
244 #endif
245 }
246
Signature()247 std::string Signature() override {
248 return c10::str("GemmAndBiasTunableOp_", TypeName<T>(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout));
249 }
250 };
251
252 template <typename T, BlasOp ALayout, BlasOp BLayout>
253 class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>, StreamTimer> {
254 public:
GemmStridedBatchedTunableOp()255 GemmStridedBatchedTunableOp() {
256 this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmStridedBatchedOp<T>>());
257
258 #ifdef USE_ROCM
259 static const char *env_rocblas = std::getenv("PYTORCH_TUNABLEOP_ROCBLAS_ENABLED");
260 if (env_rocblas == nullptr || strcmp(env_rocblas, "1") == 0) {
261 for (auto&& [name, op] : GetRocBlasGemmStridedBatchedTypeStringAndOps<T>()) {
262 this->RegisterOp(std::move(name), std::move(op));
263 }
264 }
265
266 static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
267 if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) {
268 // disallow tuning of hipblaslt with c10::complex
269 if constexpr (
270 !std::is_same_v<T, c10::complex<float>> &&
271 !std::is_same_v<T, c10::complex<double>>) {
272 for (auto&& [name, op] : GetHipBlasLtGemmStridedBatchedTypeStringAndOps<T, ALayout, BLayout>()) {
273 this->RegisterOp(std::move(name), std::move(op));
274 }
275 }
276 }
277 #endif
278 }
279
Signature()280 std::string Signature() override {
281 return c10::str("GemmStridedBatchedTunableOp_", TypeName<T>(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout));
282 }
283 };
284
285 template <typename AT, typename BT, typename CT, BlasOp ALayout, BlasOp BLayout>
286 class ScaledGemmTunableOp : public TunableOp<ScaledGemmParams<CT>, StreamTimer> {
287 public:
ScaledGemmTunableOp()288 ScaledGemmTunableOp() {
289 this->RegisterOp(std::string("Default"), std::make_unique<DefaultScaledGemmOp<CT>>());
290
291 #ifdef USE_ROCM
292 for (auto&& [name, op] : GetHipBlasLtScaledGemmTypeStringAndOps<AT, BT, CT, ALayout, BLayout>()) {
293 this->RegisterOp(std::move(name), std::move(op));
294 }
295 #endif
296 }
297
Signature()298 std::string Signature() override {
299 return c10::str("ScaledGemmTunableOp",
300 "_", TypeName<AT>(AT{}),
301 "_", TypeName<BT>(BT{}),
302 "_", TypeName<CT>(CT{}),
303 "_", BlasOpToString(ALayout), BlasOpToString(BLayout));
304 }
305 };
306
307 } // namespace at::cuda::tunable
308