// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDADataType.h>
#include <ATen/cuda/tunable/TunableOp.h>
#include <ATen/cuda/tunable/GemmCommon.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/util/StringUtil.h>

#include <hipblaslt/hipblaslt.h>
#include <hipblaslt/hipblaslt-ext.hpp>

// Check a hipBLASLt call; on failure, raise a c10 error naming the failing
// expression and the hipBLASLt status string. #undef'd at the end of this header.
#define TORCH_HIPBLASLT_CHECK(EXPR)                \
  do {                                             \
    hipblasStatus_t __err = EXPR;                  \
    TORCH_CHECK(__err == HIPBLAS_STATUS_SUCCESS,   \
                "hipblaslt error: ",               \
                hipblasStatusToString(__err),      \
                " when calling `" #EXPR "`");      \
  } while (0)

namespace at::cuda::tunable {
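
// Map a C++ element type to the corresponding HIP data type enumerator.
// The specializations below cover every element type tuned by this file.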
template <typename T>
constexpr hipblasDatatype_t HipDataTypeFor();

template <>
constexpr hipblasDatatype_t HipDataTypeFor<float>() {
  return HIP_R_32F;
}

template <>
constexpr hipblasDatatype_t HipDataTypeFor<Half>() {
  return HIP_R_16F;
}

template <>
constexpr hipblasDatatype_t HipDataTypeFor<BFloat16>() {
  return HIP_R_16BF;
}

template <>
constexpr hipblasDatatype_t HipDataTypeFor<double>() {
  return HIP_R_64F;
}

template <>
constexpr hipblasDatatype_t HipDataTypeFor<c10::Float8_e4m3fnuz>() {
  return HIP_R_8F_E4M3_FNUZ;
}

template <>
constexpr hipblasDatatype_t HipDataTypeFor<c10::Float8_e5m2fnuz>() {
  return HIP_R_8F_E5M2_FNUZ;
}
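
// The Get*FromParams overloads below give HipblasltGemmOp::Call a single
// call site for pulling variant-specific fields (batch count, strides,
// scale pointers, bias, activation) out of any of the four params structs;
// params types that lack a given field return a neutral default
// (1, 0.0, or nullptr).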
template <typename T>
int GetBatchFromParams(const GemmParams<T>* params) {
  return 1;
}

template <typename T>
int GetBatchFromParams(const GemmAndBiasParams<T>* params) {
  return 1;
}

template <typename T>
int GetBatchFromParams(const GemmStridedBatchedParams<T>* params) {
  return params->batch;
}

template <typename T>
int GetBatchFromParams(const ScaledGemmParams<T>* params) {
  return 1;
}

template <typename T>
int GetStrideAFromParams(const GemmParams<T>* params) {
  return 1;
}

template <typename T>
int GetStrideAFromParams(const GemmAndBiasParams<T>* params) {
  return 1;
}

template <typename T>
int GetStrideAFromParams(const GemmStridedBatchedParams<T>* params) {
  return params->stride_a;
}

template <typename T>
int GetStrideAFromParams(const ScaledGemmParams<T>* params) {
  return 1;
}

template <typename T>
int GetStrideBFromParams(const GemmParams<T>* params) {
  return 1;
}

template <typename T>
int GetStrideBFromParams(const GemmAndBiasParams<T>* params) {
  return 1;
}

template <typename T>
int GetStrideBFromParams(const GemmStridedBatchedParams<T>* params) {
  return params->stride_b;
}

template <typename T>
int GetStrideBFromParams(const ScaledGemmParams<T>* params) {
  return 1;
}

template <typename T>
int GetStrideCFromParams(const GemmParams<T>* params) {
  return 1;
}

template <typename T>
int GetStrideCFromParams(const GemmAndBiasParams<T>* params) {
  return 1;
}

template <typename T>
int GetStrideCFromParams(const GemmStridedBatchedParams<T>* params) {
  return params->stride_c;
}

template <typename T>
int GetStrideCFromParams(const ScaledGemmParams<T>* params) {
  return 1;
}

template <typename T>
float GetAlphaFromParams(const GemmParams<T>* params) {
  return params->alpha;
}

template <typename T>
float GetAlphaFromParams(const GemmAndBiasParams<T>* params) {
  return params->alpha;
}

template <typename T>
float GetAlphaFromParams(const GemmStridedBatchedParams<T>* params) {
  return params->alpha;
}

template <typename T>
float GetAlphaFromParams(const ScaledGemmParams<T>* params) {
  return 1.0;
}

template <typename T>
float GetBetaFromParams(const GemmParams<T>* params) {
  return params->beta;
}

template <typename T>
float GetBetaFromParams(const GemmAndBiasParams<T>* params) {
  return 0.0;
}

template <typename T>
float GetBetaFromParams(const GemmStridedBatchedParams<T>* params) {
  return params->beta;
}

template <typename T>
float GetBetaFromParams(const ScaledGemmParams<T>* params) {
  return 0.0;
}

template <typename T>
const void* GetAScalePointerFromParams(const GemmParams<T>* params) {
  return nullptr;
}

template <typename T>
const void* GetAScalePointerFromParams(const GemmAndBiasParams<T>* params) {
  return nullptr;
}

template <typename T>
const void* GetAScalePointerFromParams(const GemmStridedBatchedParams<T>* params) {
  return nullptr;
}

template <typename T>
const void* GetAScalePointerFromParams(const ScaledGemmParams<T>* params) {
  return params->a_scale_ptr;
}

template <typename T>
const void* GetBScalePointerFromParams(const GemmParams<T>* params) {
  return nullptr;
}

template <typename T>
const void* GetBScalePointerFromParams(const GemmAndBiasParams<T>* params) {
  return nullptr;
}

template <typename T>
const void* GetBScalePointerFromParams(const GemmStridedBatchedParams<T>* params) {
  return nullptr;
}

template <typename T>
const void* GetBScalePointerFromParams(const ScaledGemmParams<T>* params) {
  return params->b_scale_ptr;
}

template <typename T>
const void* GetDScalePointerFromParams(const GemmParams<T>* params) {
  return nullptr;
}

template <typename T>
const void* GetDScalePointerFromParams(const GemmAndBiasParams<T>* params) {
  return nullptr;
}

template <typename T>
const void* GetDScalePointerFromParams(const GemmStridedBatchedParams<T>* params) {
  return nullptr;
}

template <typename T>
const void* GetDScalePointerFromParams(const ScaledGemmParams<T>* params) {
  return params->c_scale_ptr;
}

template <typename T>
const void* GetBiasPointerFromParams(const GemmParams<T>* params) {
  return nullptr;
}

template <typename T>
const void* GetBiasPointerFromParams(const GemmAndBiasParams<T>* params) {
  return params->bias;
}

template <typename T>
const void* GetBiasPointerFromParams(const GemmStridedBatchedParams<T>* params) {
  return nullptr;
}

template <typename T>
const void* GetBiasPointerFromParams(const ScaledGemmParams<T>* params) {
  return params->bias_ptr;
}

template <typename T>
hipDataType GetBiasTypeFromParams(const GemmParams<T>* params) {
  return HIP_R_32F;
}

template <typename T>
hipDataType GetBiasTypeFromParams(const GemmAndBiasParams<T>* params) {
  return HipDataTypeFor<T>();
}

template <typename T>
hipDataType GetBiasTypeFromParams(const GemmStridedBatchedParams<T>* params) {
  return HIP_R_32F;
}

template <typename T>
hipDataType GetBiasTypeFromParams(const ScaledGemmParams<T>* params) {
  return at::cuda::ScalarTypeToCudaDataType(params->bias_dtype);
}

template <typename T>
at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmParams<T>* params) {
  return at::cuda::blas::GEMMAndBiasActivationEpilogue::None;
}

template <typename T>
at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmAndBiasParams<T>* params) {
  return params->activation;
}

template <typename T>
at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmStridedBatchedParams<T>* params) {
  return at::cuda::blas::GEMMAndBiasActivationEpilogue::None;
}

template <typename T>
at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const ScaledGemmParams<T>* params) {
  return at::cuda::blas::GEMMAndBiasActivationEpilogue::None;
}

static hipblasOperation_t _hipblasOpFromChar(char op) {
  switch (op) {
    case 'n':
    case 'N':
      return HIPBLAS_OP_N;
    case 't':
    case 'T':
      return HIPBLAS_OP_T;
    case 'c':
    case 'C':
      return HIPBLAS_OP_C;
  }
  AT_ERROR(
      "_hipblasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`");
}

static char _charFromhipblasOp(hipblasOperation_t op) {
  switch (op) {
    case HIPBLAS_OP_N:
      return 'N';
    case HIPBLAS_OP_T:
      return 'T';
    case HIPBLAS_OP_C:
      return 'C';
  }
  AT_ERROR(
      "_charFromhipblasOp input should be HIPBLAS_OP_N/T/C but got `", op, "`");
}

static hipblasOperation_t MapLayoutToHipBlasLt(BlasOp layout) {
  if (layout == BlasOp::N) {
    return HIPBLAS_OP_N;
  }
  return HIPBLAS_OP_T;
}

static size_t GetHipblasltWorkspaceSize() {
  static const char* env = getenv("HIPBLASLT_WORKSPACE_SIZE");
  // 256MB is max workspace size allowed for hipblaslt
  // hipblaslt-bench uses 32MB
  // recommendation from hipblaslt author was 76MB
  size_t workspace_size = 32*1024;  // default is 32 MiB, tracked here in KiB
  if (env) {
    try {
      workspace_size = std::stoi(env);
    } catch(std::invalid_argument const& e) {
      TORCH_WARN("invalid HIPBLASLT_WORKSPACE_SIZE,",
                 " using default workspace size of ", workspace_size, " KiB.");
    } catch(std::out_of_range const& e) {
      TORCH_WARN("HIPBLASLT_WORKSPACE_SIZE out of range,",
                 " using default workspace size of ", workspace_size, " KiB.");
    }
  }
  return workspace_size * 1024;  // convert KiB to bytes
}
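
// For example, running with HIPBLASLT_WORKSPACE_SIZE=76800 in the environment
// selects a 75 MiB workspace, since the value is interpreted in KiB.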

// RAII deleter: invokes the matching hipBLASLt destroy function when the
// owning unique_ptr goes out of scope.
template <typename T, hipblasStatus_t (*destructor)(T*)>
struct HipBlasLtDeleter {
  void operator()(T* x) {
    if (x != nullptr) {
      TORCH_CUDABLAS_CHECK(destructor(x));
    }
  }
};

template <typename T, hipblasStatus_t (*destructor)(T*)>
class HipBlasLtDescriptor {
 public:
  T* descriptor() const {
    return descriptor_.get();
  }
  T* descriptor() {
    return descriptor_.get();
  }

 protected:
  std::unique_ptr<T, HipBlasLtDeleter<T, destructor>> descriptor_;
};

class HipBlasLtMatmulDescriptor : public HipBlasLtDescriptor<
                                      hipblasLtMatmulDescOpaque_t,
                                      &hipblasLtMatmulDescDestroy> {
 public:
  HipBlasLtMatmulDescriptor(
      hipblasComputeType_t compute_type,
      hipDataType scale_type) {
    hipblasLtMatmulDesc_t raw_descriptor = nullptr;
    TORCH_HIPBLASLT_CHECK(
        hipblasLtMatmulDescCreate(&raw_descriptor, compute_type, scale_type));
    descriptor_.reset(raw_descriptor);
  }
  template <typename T>
  inline void setAttribute(hipblasLtMatmulDescAttributes_t attr, const T value) {
    TORCH_HIPBLASLT_CHECK(::hipblasLtMatmulDescSetAttribute(descriptor(), attr, &value, sizeof(T)));
  }
};
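
// A minimal sketch of how the wrapper is used (see Call() below for the real
// call sites); the underlying descriptor is destroyed automatically when the
// wrapper goes out of scope:
//   HipBlasLtMatmulDescriptor matmul(HIPBLAS_COMPUTE_32F, HIP_R_32F);
//   matmul.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSA, HIPBLAS_OP_N);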

template <typename AT, typename BT, typename CT, BlasOp ALayout, BlasOp BLayout, typename ParamsT>
class HipblasltGemmOp : public Callable<ParamsT> {
 public:
  HipblasltGemmOp(hipblasLtMatmulAlgo_t algo) : algo_{algo} {}

  TuningStatus Call(const ParamsT* params) override {
    hipblasOperation_t transa_outer = MapLayoutToHipBlasLt(ALayout);
    hipblasOperation_t transb_outer = MapLayoutToHipBlasLt(BLayout);
    auto a_datatype = HipDataTypeFor<AT>();
    auto b_datatype = HipDataTypeFor<BT>();
    auto in_out_datatype = HipDataTypeFor<CT>();
    auto opa = _hipblasOpFromChar(params->transa);
    auto opb = _hipblasOpFromChar(params->transb);

    TORCH_CHECK(transa_outer == opa && transb_outer == opb, "trans mismatch, shouldn't happen");

    float alpha = GetAlphaFromParams<CT>(params);
    float beta = GetBetaFromParams<CT>(params);

    hipblasLtMatrixLayout_t mat_a, mat_b, mat_c;
    if (opa == HIPBLAS_OP_N) {
      TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_a, a_datatype, params->m, params->k, params->lda));
    }
    else {
      TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_a, a_datatype, params->k, params->m, params->lda));
    }
    if (opb == HIPBLAS_OP_N) {
      TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_b, b_datatype, params->k, params->n, params->ldb));
    }
    else {
      TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_b, b_datatype, params->n, params->k, params->ldb));
    }
    TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_c, in_out_datatype, params->m, params->n, params->ldc));

    // specific to batched gemm
    int batch = GetBatchFromParams<CT>(params);
    if (batch > 1) {
      int64_t stride_a = GetStrideAFromParams<CT>(params);
      int64_t stride_b = GetStrideBFromParams<CT>(params);
      int64_t stride_c = GetStrideCFromParams<CT>(params);
      TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
          mat_a, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
      TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
          mat_a, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_a, sizeof(stride_a)));
      TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
          mat_b, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
      TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
          mat_b, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_b, sizeof(stride_b)));
      TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
          mat_c, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
      TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
          mat_c, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_c, sizeof(stride_c)));
    }

    HipBlasLtMatmulDescriptor matmul(HIPBLAS_COMPUTE_32F, HIP_R_32F);
    matmul.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSA, opa);
    matmul.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSB, opb);

    // specific to scaled gemm
    const void* mat1_scale_ptr = GetAScalePointerFromParams<CT>(params);
    const void* mat2_scale_ptr = GetBScalePointerFromParams<CT>(params);
    const void* result_scale_ptr = GetDScalePointerFromParams<CT>(params);
    if (mat1_scale_ptr && mat2_scale_ptr) {
      matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat1_scale_ptr);
      matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat2_scale_ptr);
    }
    if (result_scale_ptr) {
      matmul.setAttribute(HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr);
    }

    const void* bias_ptr = GetBiasPointerFromParams<CT>(params);
    auto bias_datatype = GetBiasTypeFromParams<CT>(params);
    if (bias_ptr) {
      matmul.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_POINTER, bias_ptr);
      matmul.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, bias_datatype);
      auto activation = GetActivationFromParams<CT>(params);
      if (activation == at::cuda::blas::GEMMAndBiasActivationEpilogue::RELU) {
        matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_RELU_BIAS);
      }
      else if (activation == at::cuda::blas::GEMMAndBiasActivationEpilogue::GELU) {
        matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_GELU_BIAS);
      }
      else {
        matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_BIAS);
      }
    }

    size_t workspace_size = GetHipblasltWorkspaceSize();

    auto op_handle = at::cuda::getCurrentCUDABlasLtHandle();

    size_t ret_workspace_size = 0;
    auto status = hipblaslt_ext::matmulIsAlgoSupported(op_handle,
                                                       matmul.descriptor(),
                                                       &alpha,
                                                       mat_a,
                                                       mat_b,
                                                       &beta,
                                                       mat_c,
                                                       mat_c,
                                                       algo_,
                                                       ret_workspace_size);

    if (status != HIPBLAS_STATUS_SUCCESS || ret_workspace_size >= workspace_size) {
      // algo is unsupported for this problem, or needs more workspace than we
      // allow; release the layouts before bailing out so that probing many
      // candidate algos during tuning does not leak
      TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_a));
      TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_b));
      TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_c));
      return FAIL;
    }

    void* workspace_buffer = nullptr;
    if (workspace_size > 0) {
      workspace_buffer = c10::cuda::CUDACachingAllocator::raw_alloc(workspace_size);
    }

    TORCH_HIPBLASLT_CHECK(hipblasLtMatmul(op_handle,
                                          matmul.descriptor(),
                                          &alpha,
                                          params->a,
                                          mat_a,
                                          params->b,
                                          mat_b,
                                          &beta,
                                          params->c,
                                          mat_c,
                                          params->c,
                                          mat_c,
                                          &algo_,
                                          workspace_buffer,
                                          workspace_size,
                                          at::cuda::getCurrentCUDAStream()));

    // the matmul descriptor is released by HipBlasLtMatmulDescriptor's RAII deleter
    TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_a));
    TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_b));
    TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_c));
    if (workspace_size > 0) {
      c10::cuda::CUDACachingAllocator::raw_delete(workspace_buffer);
    }
    return OK;
  }

 private:
  hipblasLtMatmulAlgo_t algo_;
};
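
// During tuning, the TunableOp framework times Call() for each candidate
// HipblasltGemmOp returned by GetHipBlasLtTypeStringAndOps() below and
// records the fastest candidate under its type string.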

template <typename AT, typename BT, typename CT, BlasOp ALayout, BlasOp BLayout, typename ParamsT>
auto GetHipBlasLtTypeStringAndOps() {
  hipblasOperation_t transa_outer = MapLayoutToHipBlasLt(ALayout);
  hipblasOperation_t transb_outer = MapLayoutToHipBlasLt(BLayout);
  auto a_datatype = HipDataTypeFor<AT>();
  auto b_datatype = HipDataTypeFor<BT>();
  auto in_out_datatype = HipDataTypeFor<CT>();
  std::vector<hipblasLtMatmulHeuristicResult_t> heuristic_result;

  hipblasLtHandle_t handle;
  TORCH_HIPBLASLT_CHECK(hipblasLtCreate(&handle));
  TORCH_HIPBLASLT_CHECK(hipblaslt_ext::getAllAlgos(handle,
                                                   hipblaslt_ext::GemmType::HIPBLASLT_GEMM,
                                                   transa_outer,
                                                   transb_outer,
                                                   a_datatype,
                                                   b_datatype,
                                                   in_out_datatype,
                                                   in_out_datatype,
                                                   HIPBLAS_COMPUTE_32F,
                                                   heuristic_result));
  TORCH_HIPBLASLT_CHECK(hipblasLtDestroy(handle));

  // Sort heuristic_result by algo index to make sure the order of returned algos is deterministic.
  std::sort(heuristic_result.begin(),
            heuristic_result.end(),
            [](hipblasLtMatmulHeuristicResult_t& a, hipblasLtMatmulHeuristicResult_t& b) {
              return hipblaslt_ext::getIndexFromAlgo(a.algo) < hipblaslt_ext::getIndexFromAlgo(b.algo);
            });

  int returned_algo_count = heuristic_result.size();
  std::vector<std::pair<std::string, std::unique_ptr<Callable<ParamsT>>>> ret;
  for (int i = 0; i < returned_algo_count; i++) {
    auto algo = heuristic_result[i].algo;
    int algo_index = hipblaslt_ext::getIndexFromAlgo(algo);
    auto callable = std::make_unique<HipblasltGemmOp<AT, BT, CT, ALayout, BLayout, ParamsT>>(algo);
    std::string type_string = c10::str(
        "Gemm_Hipblaslt_", _charFromhipblasOp(transa_outer), _charFromhipblasOp(transb_outer), "_", algo_index);
    ret.emplace_back(type_string, std::move(callable));
  }

  return ret;
}

template <typename T, BlasOp ALayout, BlasOp BLayout>
auto GetHipBlasLtGemmTypeStringAndOps() {
  return GetHipBlasLtTypeStringAndOps<T, T, T, ALayout, BLayout, GemmParams<T>>();
}

template <typename T, BlasOp ALayout, BlasOp BLayout>
auto GetHipBlasLtGemmAndBiasTypeStringAndOps() {
  return GetHipBlasLtTypeStringAndOps<T, T, T, ALayout, BLayout, GemmAndBiasParams<T>>();
}

template <typename T, BlasOp ALayout, BlasOp BLayout>
auto GetHipBlasLtGemmStridedBatchedTypeStringAndOps() {
  return GetHipBlasLtTypeStringAndOps<T, T, T, ALayout, BLayout, GemmStridedBatchedParams<T>>();
}

template <typename AT, typename BT, typename CT, BlasOp ALayout, BlasOp BLayout>
auto GetHipBlasLtScaledGemmTypeStringAndOps() {
  return GetHipBlasLtTypeStringAndOps<AT, BT, CT, ALayout, BLayout, ScaledGemmParams<CT>>();
}
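
// A minimal sketch of consuming the returned (name, op) pairs, assuming the
// caller has already populated a GemmParams<float> named `params`:
//   auto ops = GetHipBlasLtGemmTypeStringAndOps<float, BlasOp::N, BlasOp::N>();
//   for (auto&& [name, op] : ops) {
//     TuningStatus status = op->Call(&params);  // OK if it ran, FAIL if unsupported
//   }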

#undef TORCH_HIPBLASLT_CHECK

} // namespace at::cuda::tunable