// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDADataType.h>
#include <ATen/cuda/tunable/TunableOp.h>
#include <ATen/cuda/tunable/GemmCommon.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/util/StringUtil.h>

#include <hipblaslt/hipblaslt.h>
#include <hipblaslt/hipblaslt-ext.hpp>

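// Checks the status returned by a hipBLASLt call; on failure, raises a
// TORCH_CHECK error carrying the stringified status and the failing expression.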
#define TORCH_HIPBLASLT_CHECK(EXPR)               \
  do {                                            \
    hipblasStatus_t __err = EXPR;                 \
    TORCH_CHECK(__err == HIPBLAS_STATUS_SUCCESS,  \
                "hipblaslt error: ",              \
                hipblasStatusToString(__err),     \
                " when calling `" #EXPR "`");     \
  } while (0)

namespace at::cuda::tunable {

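// Maps a C++ element type to the matching hipBLASLt data type enum. Only the
// specializations below are provided; other element types are unsupported.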
template <typename T>
constexpr hipblasDatatype_t HipDataTypeFor();

template <>
constexpr hipblasDatatype_t HipDataTypeFor<float>() {
  return HIP_R_32F;
}

template <>
constexpr hipblasDatatype_t HipDataTypeFor<Half>() {
  return HIP_R_16F;
}

template <>
constexpr hipblasDatatype_t HipDataTypeFor<BFloat16>() {
  return HIP_R_16BF;
}

template <>
constexpr hipblasDatatype_t HipDataTypeFor<double>() {
  return HIP_R_64F;
}

template <>
constexpr hipblasDatatype_t HipDataTypeFor<c10::Float8_e4m3fnuz>() {
  return HIP_R_8F_E4M3_FNUZ;
}

template <>
constexpr hipblasDatatype_t HipDataTypeFor<c10::Float8_e5m2fnuz>() {
  return HIP_R_8F_E5M2_FNUZ;
}

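// The overload sets below give HipblasltGemmOp::Call uniform access to the
// attributes of each params variant (plain, bias, strided-batched, scaled).
// Variants that lack an attribute return a neutral default: batch 1, stride 1,
// or a null pointer.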
template <typename T>
int GetBatchFromParams(const GemmParams<T>* params) {
  return 1;
}

template <typename T>
int GetBatchFromParams(const GemmAndBiasParams<T>* params) {
  return 1;
}

template <typename T>
int GetBatchFromParams(const GemmStridedBatchedParams<T>* params) {
  return params->batch;
}

template <typename T>
int GetBatchFromParams(const ScaledGemmParams<T>* params) {
  return 1;
}

template <typename T>
int GetStrideAFromParams(const GemmParams<T>* params) {
  return 1;
}

template <typename T>
int GetStrideAFromParams(const GemmAndBiasParams<T>* params) {
  return 1;
}

template <typename T>
int GetStrideAFromParams(const GemmStridedBatchedParams<T>* params) {
  return params->stride_a;
}

template <typename T>
int GetStrideAFromParams(const ScaledGemmParams<T>* params) {
  return 1;
}

template <typename T>
int GetStrideBFromParams(const GemmParams<T>* params) {
  return 1;
}

template <typename T>
int GetStrideBFromParams(const GemmAndBiasParams<T>* params) {
  return 1;
}

template <typename T>
int GetStrideBFromParams(const GemmStridedBatchedParams<T>* params) {
  return params->stride_b;
}

template <typename T>
int GetStrideBFromParams(const ScaledGemmParams<T>* params) {
  return 1;
}

template <typename T>
int GetStrideCFromParams(const GemmParams<T>* params) {
  return 1;
}

template <typename T>
int GetStrideCFromParams(const GemmAndBiasParams<T>* params) {
  return 1;
}

template <typename T>
int GetStrideCFromParams(const GemmStridedBatchedParams<T>* params) {
  return params->stride_c;
}

template <typename T>
int GetStrideCFromParams(const ScaledGemmParams<T>* params) {
  return 1;
}

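// Alpha and beta: the bias variant pins beta to 0.0 (its epilogue supplies the
// bias rather than accumulating into C), and the scaled variant pins alpha to
// 1.0 and beta to 0.0, since scaling is applied through the scale pointers
// below rather than through the scalar factors.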
template <typename T>
float GetAlphaFromParams(const GemmParams<T>* params) {
  return params->alpha;
}

template <typename T>
float GetAlphaFromParams(const GemmAndBiasParams<T>* params) {
  return params->alpha;
}

template <typename T>
float GetAlphaFromParams(const GemmStridedBatchedParams<T>* params) {
  return params->alpha;
}

template <typename T>
float GetAlphaFromParams(const ScaledGemmParams<T>* params) {
  return 1.0;
}

template <typename T>
float GetBetaFromParams(const GemmParams<T>* params) {
  return params->beta;
}

template <typename T>
float GetBetaFromParams(const GemmAndBiasParams<T>* params) {
  return 0.0;
}

template <typename T>
float GetBetaFromParams(const GemmStridedBatchedParams<T>* params) {
  return params->beta;
}

template <typename T>
float GetBetaFromParams(const ScaledGemmParams<T>* params) {
  return 0.0;
}

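// Scale pointers (for A, B, and the output D) are only carried by
// ScaledGemmParams; the bias pointer comes from GemmAndBiasParams or
// ScaledGemmParams. All other variants report nullptr, which leaves the
// corresponding matmul descriptor attributes unset.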
template <typename T>
const void* GetAScalePointerFromParams(const GemmParams<T>* params) {
  return nullptr;
}

template <typename T>
const void* GetAScalePointerFromParams(const GemmAndBiasParams<T>* params) {
  return nullptr;
}

template <typename T>
const void* GetAScalePointerFromParams(const GemmStridedBatchedParams<T>* params) {
  return nullptr;
}

template <typename T>
const void* GetAScalePointerFromParams(const ScaledGemmParams<T>* params) {
  return params->a_scale_ptr;
}

template <typename T>
const void* GetBScalePointerFromParams(const GemmParams<T>* params) {
  return nullptr;
}

template <typename T>
const void* GetBScalePointerFromParams(const GemmAndBiasParams<T>* params) {
  return nullptr;
}

template <typename T>
const void* GetBScalePointerFromParams(const GemmStridedBatchedParams<T>* params) {
  return nullptr;
}

template <typename T>
const void* GetBScalePointerFromParams(const ScaledGemmParams<T>* params) {
  return params->b_scale_ptr;
}

template <typename T>
const void* GetDScalePointerFromParams(const GemmParams<T>* params) {
  return nullptr;
}

template <typename T>
const void* GetDScalePointerFromParams(const GemmAndBiasParams<T>* params) {
  return nullptr;
}

template <typename T>
const void* GetDScalePointerFromParams(const GemmStridedBatchedParams<T>* params) {
  return nullptr;
}

template <typename T>
const void* GetDScalePointerFromParams(const ScaledGemmParams<T>* params) {
  return params->c_scale_ptr;
}

template <typename T>
const void* GetBiasPointerFromParams(const GemmParams<T>* params) {
  return nullptr;
}

template <typename T>
const void* GetBiasPointerFromParams(const GemmAndBiasParams<T>* params) {
  return params->bias;
}

template <typename T>
const void* GetBiasPointerFromParams(const GemmStridedBatchedParams<T>* params) {
  return nullptr;
}

template <typename T>
const void* GetBiasPointerFromParams(const ScaledGemmParams<T>* params) {
  return params->bias_ptr;
}

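// Bias data type and activation epilogue. Only GemmAndBiasParams carries an
// activation; only ScaledGemmParams carries an explicit bias dtype. The
// remaining overloads default to an FP32 bias type and no activation.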
template <typename T>
hipDataType GetBiasTypeFromParams(const GemmParams<T>* params) {
  return HIP_R_32F;
}

template <typename T>
hipDataType GetBiasTypeFromParams(const GemmAndBiasParams<T>* params) {
  return HipDataTypeFor<T>();
}

template <typename T>
hipDataType GetBiasTypeFromParams(const GemmStridedBatchedParams<T>* params) {
  return HIP_R_32F;
}

template <typename T>
hipDataType GetBiasTypeFromParams(const ScaledGemmParams<T>* params) {
  return at::cuda::ScalarTypeToCudaDataType(params->bias_dtype);
}

template <typename T>
at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmParams<T>* params) {
  return at::cuda::blas::GEMMAndBiasActivationEpilogue::None;
}

template <typename T>
at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmAndBiasParams<T>* params) {
  return params->activation;
}

template <typename T>
at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmStridedBatchedParams<T>* params) {
  return at::cuda::blas::GEMMAndBiasActivationEpilogue::None;
}

template <typename T>
at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const ScaledGemmParams<T>* params) {
  return at::cuda::blas::GEMMAndBiasActivationEpilogue::None;
}

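// Conversions between BLAS transpose characters and hipBLAS operation enums,
// used to validate the params against the op's compile-time layout and to
// build each op's name string.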
static hipblasOperation_t _hipblasOpFromChar(char op) {
  switch (op) {
    case 'n':
    case 'N':
      return HIPBLAS_OP_N;
    case 't':
    case 'T':
      return HIPBLAS_OP_T;
    case 'c':
    case 'C':
      return HIPBLAS_OP_C;
  }
  AT_ERROR(
      "_hipblasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`");
}

static char _charFromhipblasOp(hipblasOperation_t op) {
  switch (op) {
    case HIPBLAS_OP_N:
      return 'N';
    case HIPBLAS_OP_T:
      return 'T';
    case HIPBLAS_OP_C:
      return 'C';
  }
  AT_ERROR(
      "_charFromhipblasOp input should be HIPBLAS_OP_N/T/C but got `", op, "`");
}

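// Maps the tunable framework's BlasOp layout tag onto a hipBLAS operation.
// BlasOp::N is non-transposed; any other layout is treated as transposed
// (conjugate-transpose is not representable as a BlasOp).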
static hipblasOperation_t MapLayoutToHipBlasLt(BlasOp layout) {
  if (layout == BlasOp::N) {
    return HIPBLAS_OP_N;
  }
  return HIPBLAS_OP_T;
}

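// Workspace size for hipBLASLt, in bytes. Read once from the
// HIPBLASLT_WORKSPACE_SIZE environment variable, which is interpreted in KiB
// (e.g. HIPBLASLT_WORKSPACE_SIZE=76800 requests 75 MiB); defaults to 32 MiB.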
static size_t GetHipblasltWorkspaceSize() {
  static const char * env = getenv("HIPBLASLT_WORKSPACE_SIZE");
  // 256MB is max workspace size allowed for hipblaslt
  // hipblaslt-bench uses 32MB
  // recommendation from hipblaslt author was 76MB
  size_t workspace_size = 32*1024;  // in KiB; going with 32 MiB
  if (env) {
    try {
      workspace_size = std::stoi(env);
    } catch(std::invalid_argument const& e) {
      TORCH_WARN("invalid HIPBLASLT_WORKSPACE_SIZE,",
                 " using default workspace size of ", workspace_size, " KiB.");
    } catch(std::out_of_range const& e) {
      TORCH_WARN("HIPBLASLT_WORKSPACE_SIZE out of range,",
                 " using default workspace size of ", workspace_size, " KiB.");
    }
  }
  return workspace_size * 1024;
}

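// RAII wrappers for hipBLASLt descriptor handles. (This header lives under
// aten/src/ATen/cuda, so the cublas* spellings below are presumably rewritten
// to their hip equivalents by PyTorch's hipify step.)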
template <typename T, cublasStatus_t (*destructor)(T*)>
struct HipBlasLtDeleter {
  void operator()(T* x) {
    if (x != nullptr) {
      TORCH_CUDABLAS_CHECK(destructor(x));
    }
  }
};

template <typename T, hipblasStatus_t (*destructor)(T*)>
class HipBlasLtDescriptor {
 public:
  T* descriptor() const {
    return descriptor_.get();
  }
  T* descriptor() {
    return descriptor_.get();
  }

 protected:
  std::unique_ptr<T, HipBlasLtDeleter<T, destructor>> descriptor_;
};

class HipBlasLtMatmulDescriptor : public HipBlasLtDescriptor<
                                     hipblasLtMatmulDescOpaque_t,
                                     &hipblasLtMatmulDescDestroy> {
 public:
  HipBlasLtMatmulDescriptor(
      hipblasComputeType_t compute_type,
      hipDataType scale_type) {
    hipblasLtMatmulDesc_t raw_descriptor = nullptr;
    TORCH_HIPBLASLT_CHECK(
        hipblasLtMatmulDescCreate(&raw_descriptor, compute_type, scale_type));
    descriptor_.reset(raw_descriptor);
  }
  template <typename T>
  inline void setAttribute(hipblasLtMatmulDescAttributes_t attr, const T value) {
    TORCH_HIPBLASLT_CHECK(::hipblasLtMatmulDescSetAttribute(descriptor(), attr, &value, sizeof(T)));
  }
};

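// One tunable GEMM candidate: wraps a single hipblasLtMatmulAlgo_t and runs
// the matmul described by ParamsT with it. Call() returns FAIL when hipBLASLt
// reports the algo unsupported for this problem or when the algo needs more
// workspace than GetHipblasltWorkspaceSize() allows; otherwise it runs the
// matmul and returns OK.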
template <typename AT, typename BT, typename CT, BlasOp ALayout, BlasOp BLayout, typename ParamsT>
class HipblasltGemmOp : public Callable<ParamsT> {
  public:
    HipblasltGemmOp(hipblasLtMatmulAlgo_t algo) : algo_{algo} {}

    TuningStatus Call(const ParamsT* params) override {
      hipblasOperation_t transa_outer = MapLayoutToHipBlasLt(ALayout);
      hipblasOperation_t transb_outer = MapLayoutToHipBlasLt(BLayout);
      auto a_datatype = HipDataTypeFor<AT>();
      auto b_datatype = HipDataTypeFor<BT>();
      auto in_out_datatype = HipDataTypeFor<CT>();
      auto opa = _hipblasOpFromChar(params->transa);
      auto opb = _hipblasOpFromChar(params->transb);

      TORCH_CHECK(transa_outer == opa && transb_outer == opb, "trans mismatch, shouldn't happen");

      float alpha = GetAlphaFromParams<CT>(params);
      float beta = GetBetaFromParams<CT>(params);

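      // hipBLASLt layouts describe each matrix as stored, before the transpose
      // ops apply: non-transposed A is m x k with leading dimension lda, a
      // transposed A is stored k x m, and B is handled symmetrically.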
      hipblasLtMatrixLayout_t mat_a, mat_b, mat_c;
      if (opa == HIPBLAS_OP_N) {
        TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_a, a_datatype, params->m, params->k, params->lda));
      }
      else {
        TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_a, a_datatype, params->k, params->m, params->lda));
      }
      if (opb == HIPBLAS_OP_N) {
        TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_b, b_datatype, params->k, params->n, params->ldb));
      }
      else {
        TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_b, b_datatype, params->n, params->k, params->ldb));
      }
      TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_c, in_out_datatype, params->m, params->n, params->ldc));

      // specific to batched gemm
      int batch = GetBatchFromParams<CT>(params);
      if (batch > 1) {
        int64_t stride_a = GetStrideAFromParams<CT>(params);
        int64_t stride_b = GetStrideBFromParams<CT>(params);
        int64_t stride_c = GetStrideCFromParams<CT>(params);
        TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
            mat_a, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
        TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
            mat_a, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_a, sizeof(stride_a)));
        TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
            mat_b, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
        TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
            mat_b, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_b, sizeof(stride_b)));
        TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
            mat_c, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
        TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
            mat_c, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_c, sizeof(stride_c)));
      }

      HipBlasLtMatmulDescriptor matmul(HIPBLAS_COMPUTE_32F, HIP_R_32F);
      matmul.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSA, opa);
      matmul.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSB, opb);

      // specific to scaled gemm
      const void* mat1_scale_ptr = GetAScalePointerFromParams<CT>(params);
      const void* mat2_scale_ptr = GetBScalePointerFromParams<CT>(params);
      const void* result_scale_ptr = GetDScalePointerFromParams<CT>(params);
      if (mat1_scale_ptr && mat2_scale_ptr) {
        matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat1_scale_ptr);
        matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat2_scale_ptr);
      }
      if (result_scale_ptr) {
        matmul.setAttribute(HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr);
      }

      const void* bias_ptr = GetBiasPointerFromParams<CT>(params);
      auto bias_datatype = GetBiasTypeFromParams<CT>(params);
      if (bias_ptr) {
        matmul.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_POINTER, bias_ptr);
        matmul.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, bias_datatype);
        auto activation = GetActivationFromParams<CT>(params);
        if (activation == at::cuda::blas::GEMMAndBiasActivationEpilogue::RELU) {
          matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_RELU_BIAS);
        }
        else if (activation == at::cuda::blas::GEMMAndBiasActivationEpilogue::GELU) {
          matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_GELU_BIAS);
        }
        else {
          matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_BIAS);
        }
      }

      size_t workspace_size = GetHipblasltWorkspaceSize();

      auto op_handle = at::cuda::getCurrentCUDABlasLtHandle();

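      // Probe whether this algo supports the problem and how much workspace it
      // would need; an unsupported or over-budget algo is reported as FAIL so
      // tuning moves on to the next candidate instead of erroring out.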
      size_t ret_workspace_size = 0;
      auto status = hipblaslt_ext::matmulIsAlgoSupported(op_handle,
          matmul.descriptor(),
          &alpha,
          mat_a,
          mat_b,
          &beta,
          mat_c,
          mat_c,
          algo_,
          ret_workspace_size);

      if (status == HIPBLAS_STATUS_SUCCESS) {
        if (ret_workspace_size >= workspace_size) {
          return FAIL;
        }
      }
      else {
        return FAIL;
      }

      void* workspace_buffer = nullptr;
      if (workspace_size > 0) {
        workspace_buffer = c10::cuda::CUDACachingAllocator::raw_alloc(workspace_size);
      }

      TORCH_HIPBLASLT_CHECK(hipblasLtMatmul(op_handle,
            matmul.descriptor(),
            &alpha,
            params->a,
            mat_a,
            params->b,
            mat_b,
            &beta,
            params->c,
            mat_c,
            params->c,
            mat_c,
            &algo_,
            workspace_buffer,
            workspace_size,
            at::cuda::getCurrentCUDAStream()));

      //TORCH_HIPBLASLT_CHECK(hipblasLtMatmulDescDestroy(matmul));
      TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_a));
      TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_b));
      TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_c));
      if (workspace_size > 0) {
        c10::cuda::CUDACachingAllocator::raw_delete(workspace_buffer);
      }
      return OK;
    }

  private:
    hipblasLtMatmulAlgo_t algo_;
};

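// Enumerates all hipBLASLt algos for the given type/layout combination and
// wraps each in a HipblasltGemmOp, keyed by a name of the form
// "Gemm_Hipblaslt_<TA><TB>_<algo index>". Results are sorted by algo index so
// the returned order (and therefore tuning behavior) is deterministic.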
template <typename AT, typename BT, typename CT, BlasOp ALayout, BlasOp BLayout, typename ParamsT>
auto GetHipBlasLtTypeStringAndOps() {
  hipblasOperation_t transa_outer = MapLayoutToHipBlasLt(ALayout);
  hipblasOperation_t transb_outer = MapLayoutToHipBlasLt(BLayout);
  auto a_datatype = HipDataTypeFor<AT>();
  auto b_datatype = HipDataTypeFor<BT>();
  auto in_out_datatype = HipDataTypeFor<CT>();
  std::vector<hipblasLtMatmulHeuristicResult_t> heuristic_result;

  hipblasLtHandle_t handle;
  TORCH_HIPBLASLT_CHECK(hipblasLtCreate(&handle));
  TORCH_HIPBLASLT_CHECK(hipblaslt_ext::getAllAlgos(handle,
        hipblaslt_ext::GemmType::HIPBLASLT_GEMM,
        transa_outer,
        transb_outer,
        a_datatype,
        b_datatype,
        in_out_datatype,
        in_out_datatype,
        HIPBLAS_COMPUTE_32F,
        heuristic_result));
  TORCH_HIPBLASLT_CHECK(hipblasLtDestroy(handle));

  // Sort heuristic_result by algo index to make sure the order of returned algos is deterministic.
  std::sort(heuristic_result.begin(),
      heuristic_result.end(),
      [](hipblasLtMatmulHeuristicResult_t& a, hipblasLtMatmulHeuristicResult_t& b) {
      return hipblaslt_ext::getIndexFromAlgo(a.algo) < hipblaslt_ext::getIndexFromAlgo(b.algo);
      });

  int returned_algo_count = heuristic_result.size();
  std::vector<std::pair<std::string, std::unique_ptr<Callable<ParamsT>>>> ret;
  for (int i = 0; i < returned_algo_count; i++) {
    auto algo = heuristic_result[i].algo;
    int algo_index = hipblaslt_ext::getIndexFromAlgo(algo);
    auto callable = std::make_unique<HipblasltGemmOp<AT, BT, CT, ALayout, BLayout, ParamsT>>(algo);
    std::string type_string = c10::str(
        "Gemm_Hipblaslt_", _charFromhipblasOp(transa_outer), _charFromhipblasOp(transb_outer), "_", algo_index);
    ret.emplace_back(type_string, std::move(callable));
  }

  return ret;
}

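// Convenience wrappers binding the generic factory to each params variant.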
template <typename T, BlasOp ALayout, BlasOp BLayout>
auto GetHipBlasLtGemmTypeStringAndOps() {
  return GetHipBlasLtTypeStringAndOps<T, T, T, ALayout, BLayout, GemmParams<T>>();
}

template <typename T, BlasOp ALayout, BlasOp BLayout>
auto GetHipBlasLtGemmAndBiasTypeStringAndOps() {
  return GetHipBlasLtTypeStringAndOps<T, T, T, ALayout, BLayout, GemmAndBiasParams<T>>();
}

template <typename T, BlasOp ALayout, BlasOp BLayout>
auto GetHipBlasLtGemmStridedBatchedTypeStringAndOps() {
  return GetHipBlasLtTypeStringAndOps<T, T, T, ALayout, BLayout, GemmStridedBatchedParams<T>>();
}

template <typename AT, typename BT, typename CT, BlasOp ALayout, BlasOp BLayout>
auto GetHipBlasLtScaledGemmTypeStringAndOps() {
  return GetHipBlasLtTypeStringAndOps<AT, BT, CT, ALayout, BLayout, ScaledGemmParams<CT>>();
}

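// Illustrative sketch of how a caller can consume these factories (assumes a
// populated GemmParams<float> named `params`; not part of this header):
//
//   auto ops = GetHipBlasLtGemmTypeStringAndOps<float, BlasOp::T, BlasOp::N>();
//   for (auto& [name, op] : ops) {
//     TuningStatus status = op->Call(&params);  // OK if this algo ran
//   }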
#undef TORCH_HIPBLASLT_CHECK

}  // namespace at::cuda::tunable