xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "third_party/gpus/cuda/include/cublas_v2.h"
17 #include "third_party/gpus/cuda/include/cuda.h"
18 
19 #define SE_CUDA_DATA_HALF CUDA_R_16F
20 
21 #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.h"
22 
23 // Both Eigen Half.h and CUDA cuda_fp16.h provide similar typedefs for __half.
24 // As such, there are two ways to get the typedef for __half:
25 //
26 // (1) Include cuda_fp16.h and define EIGEN_HAS_CUDA_FP16.
27 // (2) Neither include cuda_fp16.h nor define EIGEN_HAS_CUDA_FP16.
28 //
29 // Due to issue b/73793421, when the first approach is used and NVCC is used to
30 // compile this file, NVCC will complain about a duplicate definition of
31 // EIGEN_HAS_CUDA_FP16. On the other hand, when the second approach is used and
32 // clang is used to compile this file, clang will not understand __half,
33 // because the definition and the EIGEN_HAS_CUDA_FP16 macro are missing.
34 //
35 // Because this file may be compiled with clang but will never be compiled with
36 // NVCC, we choose the first approach for CUDA < 9.0. For CUDA >= 9.0, we have
37 // to use the second approach because the data member of the __half defined
38 // by CUDA >= 9.0 is `__x`, while Eigen expects it to be `x`.
39 //
40 // TODO(b/73793421): Remove the following code block to switch to the second
41 // approach when the issue is fixed.
42 #if CUDA_VERSION < 9000
43 #include "third_party/gpus/cuda/include/cuda_fp16.h"
44 #define EIGEN_HAS_CUDA_FP16
45 #endif
46 
47 #include <complex>
48 
49 #include "absl/strings/str_cat.h"
50 #include "absl/strings/str_format.h"
51 #include "third_party/eigen3/Eigen/Core"
52 #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_activation.h"
53 #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_blas_utils.h"
54 #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.h"
55 #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_helpers.h"
56 #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_platform_id.h"
57 #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_stream.h"
58 #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_timer.h"
59 #include "tensorflow/compiler/xla/stream_executor/device_memory.h"
60 #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_executor.h"
61 #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_helpers.h"
62 #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_stream.h"
63 #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_timer.h"
64 #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_types.h"
65 #include "tensorflow/compiler/xla/stream_executor/lib/initialize.h"
66 #include "tensorflow/compiler/xla/stream_executor/lib/status.h"
67 #include "tensorflow/compiler/xla/stream_executor/platform/logging.h"
68 #include "tensorflow/compiler/xla/stream_executor/platform/port.h"
69 #include "tensorflow/compiler/xla/stream_executor/plugin_registry.h"
70 #include "tensorflow/compiler/xla/stream_executor/scratch_allocator.h"
71 #include "tensorflow/compiler/xla/stream_executor/stream_executor.h"
72 #include "tensorflow/core/platform/tensor_float_32_utils.h"
73 
74 namespace stream_executor {
75 namespace cuda {
76 
77 using gpu::AsGpuStream;
78 using gpu::AsGpuStreamValue;
79 using gpu::GpuComplex;
80 using gpu::GpuComplexT;
81 using gpu::GpuComplexType;
82 using gpu::GpuComplexValue;
83 using gpu::GpuDoubleComplexType;
84 using gpu::GpuExecutor;
85 using gpu::GpuMemory;
86 using gpu::GpuMemoryMutable;
87 using gpu::GpuTimer;
88 using gpu::GpuTimerDeleter;
89 
90 PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kCuBlasPlugin);
91 
92 // cuBLAS has interfaces that permit pointers to be passed from either the host
93 // memory space or the device memory space; however, you must instruct it as to
94 // which address space those pointers are in with cublasSetPointerMode.
95 //
96 // This helper sets the cuBLAS pointer mode to a desired value for a cuBLAS call
97 // you are about to perform in a given scope.
98 //
99 // The prior cuBLAS pointer mode is retained and restored when this object goes
100 // out of scope.
101 class ScopedCublasPointerMode {
102  public:
103   // Note that, because the setting of the cublas pointer mode is fallible,
104   // construction of this scoped datatype must be paired with a call to
105   // Init().
106   //
107   // Parameters:
108   //  handle: The cublas library handle to act upon in setting the pointer mode.
109   explicit ScopedCublasPointerMode(cublasHandle_t handle)
110       : handle_(handle), ok_(false) {}
111 
112   // Attempts the switch to the requested scoped pointer mode, new_mode.
113   //
114   // Note that when false is returned, an appropriate error has already been
115   // logged.
116   bool Init(cublasPointerMode_t new_mode) {
117     cublasStatus_t ret = cublasGetPointerMode(handle_, &old_mode_);
118     if (ret != CUBLAS_STATUS_SUCCESS) {
119       LOG(ERROR) << "failed to get old cublas pointer mode: " << ToString(ret);
120       return ok_ = false;
121     }
122 
123     ret = cublasSetPointerMode(handle_, new_mode);
124     if (ret != CUBLAS_STATUS_SUCCESS) {
125       LOG(ERROR) << "failed to set new cublas pointer mode: " << ToString(ret);
126       return ok_ = false;
127     }
128 
129     return ok_ = true;
130   }
131 
132   // Switches back to the prior pointer mode, if the switch operation was
133   // successful in the first place.
134   ~ScopedCublasPointerMode() {
135     if (ok_) {
136       cublasStatus_t ret = cublasSetPointerMode(handle_, old_mode_);
137       if (ret != CUBLAS_STATUS_SUCCESS) {
138         LOG(ERROR) << "failed to set former cublas pointer mode: "
139                    << ToString(ret);
140       }
141     }
142   }
143 
144  private:
145   cublasHandle_t handle_;         // Handle to the cuBLAS instance of interest.
146   cublasPointerMode_t old_mode_;  // Prior cuBLAS pointer mode, to be restored.
147   bool ok_;                       // Whether the change was successful.
148 };
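
// Example usage of the scoped helper above: a minimal sketch only, assuming a
// valid cublasHandle_t `handle`, a device buffer `device_x` of length `n`, and
// a host-resident alpha. The real call site is DoBlasInternalImpl below.
//
//   ScopedCublasPointerMode pointer_mode{handle};
//   if (!pointer_mode.Init(CUBLAS_POINTER_MODE_HOST)) {
//     return false;  // Init() has already logged the error.
//   }
//   float alpha = 2.0f;  // Host pointer, consistent with the mode set above.
//   cublasSscal(handle, n, &alpha, device_x, /*incx=*/1);
//   // The prior pointer mode is restored when `pointer_mode` goes out of scope.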
149 
150 #if CUDA_VERSION >= 9000
151 // cuBLAS has interfaces that let computations use Tensor Cores (Volta and
152 // later hardware). This must be enabled via the cublasGet/SetMathMode APIs.
153 //
154 // This helper sets the cuBLAS math mode to a desired value for a cuBLAS call
155 // you are about to perform in a given scope.
156 //
157 // The prior cuBLAS math mode is retained and restored when this object goes
158 // out of scope.
159 class ScopedCublasMathMode {
160  public:
161   // Note that, because the setting of the cublas math mode is fallible,
162   // construction of this scoped datatype must be paired with a call to
163   // Init().
164   //
165   // Parameters:
166   //  handle: The cublas library handle to act upon in setting the math mode.
167   explicit ScopedCublasMathMode(cublasHandle_t handle)
168       : handle_(handle), ok_(false) {}
169 
170   // Attempts the switch to the requested scoped math mode, new_mode.
171   //
172   // Note that when false is returned, an appropriate error has already been
173   // logged.
174   bool Init(cublasMath_t new_mode) {
175     cublasStatus_t ret = cublasGetMathMode(handle_, &old_mode_);
176     if (ret != CUBLAS_STATUS_SUCCESS) {
177       LOG(ERROR) << "failed to get old cublas math mode: " << ToString(ret);
178       return ok_ = false;
179     }
180 
181     ret = cublasSetMathMode(handle_, new_mode);
182     if (ret != CUBLAS_STATUS_SUCCESS) {
183       LOG(ERROR) << "failed to set new cublas math mode: " << ToString(ret);
184       return ok_ = false;
185     }
186     return ok_ = true;
187   }
188 
189   // Switches back to the prior math mode, if the switch operation was
190   // successful in the first place.
191   ~ScopedCublasMathMode() {
192     if (ok_) {
193       cublasStatus_t ret = cublasSetMathMode(handle_, old_mode_);
194       if (ret != CUBLAS_STATUS_SUCCESS) {
195         LOG(ERROR) << "failed to set former cublas math mode: "
196                    << ToString(ret);
197       }
198     }
199   }
200 
201  private:
202   cublasHandle_t handle_;  // Handle to the cuBLAS instance of interest.
203   cublasMath_t old_mode_;  // Prior cuBLAS math mode, to be restored.
204   bool ok_;                // Whether the change was successful.
205 };
206 #endif  // CUDA_VERSION >= 9000
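
// A minimal sketch of the math-mode helper above, assuming a valid
// cublasHandle_t `handle`; DoBlasInternalImpl below applies the same pattern
// with CUBLAS_TENSOR_OP_MATH or CUBLAS_TF32_TENSOR_OP_MATH depending on the
// cuBLAS version.
//
//   ScopedCublasMathMode math_mode{handle};
//   if (math_mode.Init(CUBLAS_TENSOR_OP_MATH)) {
//     // Subsequent cuBLAS calls on `handle` may use tensor ops; the previous
//     // math mode is restored when `math_mode` goes out of scope.
//   }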
207 
208 static const char *const kCublasNotInitializedExplanation =
209     "Failure to initialize cublas may be due to OOM (cublas needs some free "
210     "memory when you initialize it, and your deep-learning framework may have "
211     "preallocated more than its fair share), or may be because this binary was "
212     "not built with support for the GPU in your machine.";
213 
214 bool CUDABlas::Init() {
215   gpu::ScopedActivateExecutorContext sac{parent_};
216   cublasStatus_t ret = cublasCreate(&blas_);
217   if (ret != CUBLAS_STATUS_SUCCESS) {
218     LOG(ERROR) << "failed to create cublas handle: " << ToString(ret);
219     if (ret == CUBLAS_STATUS_NOT_INITIALIZED) {
220       LOG(ERROR) << kCublasNotInitializedExplanation;
221     }
222     return false;
223   }
224 
225 #if CUDA_VERSION >= 11000
226   if (!blas_lt_.Init().ok()) {
227     LOG(ERROR) << kCublasNotInitializedExplanation;
228     return false;
229   }
230 #endif  // CUDA_VERSION >= 11000
231 
232   return true;
233 }
234 
235 CUDABlas::CUDABlas(gpu::GpuExecutor *parent)
236     : parent_(CHECK_NOTNULL(parent)),
237       blas_(nullptr)
238 #if CUDA_VERSION >= 11000
239       ,
240       blas_lt_(parent)
241 #endif
242 {
243 }
244 
245 CUDABlas::~CUDABlas() {
246   if (blas_ != nullptr) {
247     gpu::ScopedActivateExecutorContext sac{parent_};
248     cublasDestroy(blas_);
249   }
250 }
251 
252 bool CUDABlas::SetStream(Stream *stream) {
253   CHECK(stream != nullptr);
254   CHECK(AsGpuStreamValue(stream) != nullptr);
255   CHECK(blas_ != nullptr);
256   gpu::ScopedActivateExecutorContext sac{parent_};
257   cublasStatus_t ret = cublasSetStream(blas_, AsGpuStreamValue(stream));
258   if (ret != CUBLAS_STATUS_SUCCESS) {
259     LOG(ERROR) << "failed to set stream for cuBLAS calls: " << ToString(ret);
260     return false;
261   }
262 
263   return true;
264 }
265 
266 cudaStream_t CUDABlas::CUDAStream(Stream *stream) {
267   CHECK(stream != nullptr);
268   CHECK(AsGpuStreamValue(stream) != nullptr);
269   gpu::ScopedActivateExecutorContext sac{parent_};
270   return AsGpuStreamValue(stream);
271 }
272 
273 namespace {
274 
275 // Helper functions transforming blas arguments into cuBLAS arguments.
276 
277 cublasFillMode_t CUDABlasUpperLower(blas::UpperLower uplo) {
278   switch (uplo) {
279     case blas::UpperLower::kUpper:
280       return CUBLAS_FILL_MODE_UPPER;
281     case blas::UpperLower::kLower:
282       return CUBLAS_FILL_MODE_LOWER;
283     default:
284       LOG(FATAL) << "Invalid value of blas::UpperLower.";
285   }
286 }
287 
288 cublasDiagType_t CUDABlasDiagonal(blas::Diagonal diag) {
289   switch (diag) {
290     case blas::Diagonal::kUnit:
291       return CUBLAS_DIAG_UNIT;
292     case blas::Diagonal::kNonUnit:
293       return CUBLAS_DIAG_NON_UNIT;
294     default:
295       LOG(FATAL) << "Invalid value of blas::Diagonal.";
296   }
297 }
298 
299 cublasSideMode_t CUDABlasSide(blas::Side side) {
300   switch (side) {
301     case blas::Side::kLeft:
302       return CUBLAS_SIDE_LEFT;
303     case blas::Side::kRight:
304       return CUBLAS_SIDE_RIGHT;
305     default:
306       LOG(FATAL) << "Invalid value of blas::Side.";
307   }
308 }
309 
310 // CUDADataType<T>::type translates from a C++ type (e.g. float) to a
311 // cudaDataType_t (e.g. CUDA_R_32F).
312 //
313 // These are used to build the argument type and computation type args to
314 // cublasGemmEx.
315 template <typename T>
316 struct CUDADataType;
317 
318 template <>
319 struct CUDADataType<Eigen::half> {
320   static constexpr cudaDataType_t type = SE_CUDA_DATA_HALF;
321 };
322 
323 template <>
324 struct CUDADataType<std::complex<Eigen::half>> {
325   static constexpr cudaDataType_t type = CUDA_C_16F;
326 };
327 
328 template <>
329 struct CUDADataType<float> {
330   static constexpr cudaDataType_t type = CUDA_R_32F;
331 };
332 
333 template <>
334 struct CUDADataType<std::complex<float>> {
335   static constexpr cudaDataType_t type = CUDA_C_32F;
336 };
337 
338 template <>
339 struct CUDADataType<double> {
340   static constexpr cudaDataType_t type = CUDA_R_64F;
341 };
342 
343 template <>
344 struct CUDADataType<std::complex<double>> {
345   static constexpr cudaDataType_t type = CUDA_C_64F;
346 };
347 
348 template <>
349 struct CUDADataType<int> {
350   static constexpr cudaDataType_t type = CUDA_R_32I;
351 };
352 
353 template <>
354 struct CUDADataType<int8> {
355   static constexpr cudaDataType_t type = CUDA_R_8I;
356 };
357 
358 template <>
359 struct CUDADataType<std::complex<int8>> {
360   static constexpr cudaDataType_t type = CUDA_C_8I;
361 };
362 
363 template <>
364 struct CUDADataType<uint8> {
365   static constexpr cudaDataType_t type = CUDA_R_8U;
366 };
367 
368 template <>
369 struct CUDADataType<std::complex<uint8>> {
370   static constexpr cudaDataType_t type = CUDA_C_8U;
371 };
372 
373 }  // namespace
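
// Illustrative checks of the CUDADataType traits above (a sketch; the real use
// is the CUDADataType<T>::type lookup in DoBlasGemmBatchedInternal below):
//
//   static_assert(CUDADataType<float>::type == CUDA_R_32F, "");
//   static_assert(CUDADataType<Eigen::half>::type == CUDA_R_16F, "");
//   static_assert(CUDADataType<std::complex<double>>::type == CUDA_C_64F, "");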
374 
375 template <typename FuncT, typename... Args>
376 port::Status CUDABlas::DoBlasInternalImpl(FuncT cublas_func, Stream *stream,
377                                           bool pointer_mode_host,
378                                           cublasMath_t math_type,
379                                           Args... args) {
380   absl::MutexLock lock(&mu_);
381 
382   CHECK(blas_ != nullptr);
383   if (!SetStream(stream)) {
384     return port::InternalError("Failed setting stream");
385   }
386 
387 #if CUDA_VERSION >= 9000
388   ScopedCublasMathMode math_mode{blas_};
389 #if CUBLAS_VER_MAJOR >= 11
390   if (math_type == CUBLAS_TF32_TENSOR_OP_MATH &&
391       tensorflow::tensor_float_32_execution_enabled()) {
392 #else
393   if (math_type == CUBLAS_TENSOR_OP_MATH) {
394 #endif
395     if (!math_mode.Init(math_type)) {
396       return port::InternalError("Failed initializing math mode");
397     }
398   }
399 #endif
400 
401   gpu::ScopedActivateExecutorContext sac{parent_};
402   ScopedCublasPointerMode pointer_mode{blas_};
403   if (!pointer_mode.Init(pointer_mode_host ? CUBLAS_POINTER_MODE_HOST
404                                            : CUBLAS_POINTER_MODE_DEVICE)) {
405     return port::InternalError("Failed setting pointer mode");
406   }
407   cublasStatus_t ret = cublas_func(blas_, args...);
408   if (ret == CUBLAS_STATUS_SUCCESS) {
409     return ::tensorflow::OkStatus();
410   }
411   return port::InternalError(ToString(ret));
412 }
413 
414 // cublas_func may be overloaded, so we need to figure out which one we really
415 // need to call based on the args. One way to do it is to wrap it in a lambda.
416 #define AS_LAMBDA(func)                                            \
417   [](auto &&...args) -> decltype(func(                             \
418                          std::forward<decltype(args)>(args)...)) { \
419     return func(std::forward<decltype(args)>(args)...);            \
420   }
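
// Illustrative expansion (a sketch): AS_LAMBDA(cublasGemmEx) yields a generic
// lambda, so overload resolution for cublasGemmEx is deferred until the call
// inside DoBlasInternalImpl, once the argument types are known.
//
//   auto gemm_ex = AS_LAMBDA(cublasGemmEx);
//   // gemm_ex(blas_, ...) now forwards to the matching cublasGemmEx overload.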
421 
422 bool CUDABlas::DoBlasAxpy(Stream *stream, uint64_t elem_count, float alpha,
423                           const DeviceMemory<float> &x, int incx,
424                           DeviceMemory<float> *y, int incy) {
425   return DoBlasInternal(cublasSaxpy, stream, true /* = pointer_mode_host */,
426                         elem_count, &alpha, GpuMemory(x), incx,
427                         GpuMemoryMutable(y), incy);
428 }
429 
430 bool CUDABlas::DoBlasAxpy(Stream *stream, uint64_t elem_count, double alpha,
431                           const DeviceMemory<double> &x, int incx,
432                           DeviceMemory<double> *y, int incy) {
433   return DoBlasInternal(cublasDaxpy, stream, true /* = pointer_mode_host */,
434                         elem_count, &alpha, GpuMemory(x), incx,
435                         GpuMemoryMutable(y), incy);
436 }
437 
438 bool CUDABlas::DoBlasAxpy(Stream *stream, uint64_t elem_count,
439                           std::complex<float> alpha,
440                           const DeviceMemory<std::complex<float>> &x, int incx,
441                           DeviceMemory<std::complex<float>> *y, int incy) {
442   auto cb_alpha = GpuComplexValue(alpha);
443   return DoBlasInternal(cublasCaxpy, stream, true /* = pointer_mode_host */,
444                         elem_count, GpuComplex(&cb_alpha),
445                         GpuComplex(GpuMemory(x)), incx,
446                         GpuComplex(GpuMemoryMutable(y)), incy);
447 }
448 
449 bool CUDABlas::DoBlasAxpy(Stream *stream, uint64_t elem_count,
450                           std::complex<double> alpha,
451                           const DeviceMemory<std::complex<double>> &x, int incx,
452                           DeviceMemory<std::complex<double>> *y, int incy) {
453   auto cb_alpha = GpuComplexValue(alpha);
454   return DoBlasInternal(cublasZaxpy, stream, true /* = pointer_mode_host */,
455                         elem_count, GpuComplex(&cb_alpha),
456                         GpuComplex(GpuMemory(x)), incx,
457                         GpuComplex(GpuMemoryMutable(y)), incy);
458 }
459 
460 bool CUDABlas::DoBlasCopy(Stream *stream, uint64_t elem_count,
461                           const DeviceMemory<float> &x, int incx,
462                           DeviceMemory<float> *y, int incy) {
463   return DoBlasInternal(cublasScopy, stream, true /* = pointer_mode_host */,
464                         elem_count, GpuMemory(x), incx, GpuMemoryMutable(y),
465                         incy);
466 }
467 
468 bool CUDABlas::DoBlasCopy(Stream *stream, uint64_t elem_count,
469                           const DeviceMemory<double> &x, int incx,
470                           DeviceMemory<double> *y, int incy) {
471   return DoBlasInternal(cublasDcopy, stream, true /* = pointer_mode_host */,
472                         elem_count, GpuMemory(x), incx, GpuMemoryMutable(y),
473                         incy);
474 }
475 
476 bool CUDABlas::DoBlasCopy(Stream *stream, uint64_t elem_count,
477                           const DeviceMemory<std::complex<float>> &x, int incx,
478                           DeviceMemory<std::complex<float>> *y, int incy) {
479   return DoBlasInternal(cublasCcopy, stream, true /* = pointer_mode_host */,
480                         elem_count, GpuComplex(GpuMemory(x)), incx,
481                         GpuComplex(GpuMemoryMutable(y)), incy);
482 }
483 
484 bool CUDABlas::DoBlasCopy(Stream *stream, uint64_t elem_count,
485                           const DeviceMemory<std::complex<double>> &x, int incx,
486                           DeviceMemory<std::complex<double>> *y, int incy) {
487   return DoBlasInternal(cublasZcopy, stream, true /* = pointer_mode_host */,
488                         elem_count, GpuComplex(GpuMemory(x)), incx,
489                         GpuComplex(GpuMemoryMutable(y)), incy);
490 }
491 
492 bool CUDABlas::DoBlasScal(Stream *stream, uint64_t elem_count, float alpha,
493                           DeviceMemory<float> *x, int incx) {
494   return DoBlasInternal(cublasSscal, stream, true /* = pointer_mode_host */,
495                         elem_count, &alpha, GpuMemoryMutable(x), incx);
496 }
497 
498 bool CUDABlas::DoBlasScal(Stream *stream, uint64_t elem_count, double alpha,
499                           DeviceMemory<double> *x, int incx) {
500   return DoBlasInternal(cublasDscal, stream, true /* = pointer_mode_host */,
501                         elem_count, &alpha, GpuMemoryMutable(x), incx);
502 }
503 
504 bool CUDABlas::DoBlasScal(Stream *stream, uint64_t elem_count, float alpha,
505                           DeviceMemory<std::complex<float>> *x, int incx) {
506   return DoBlasInternal(cublasCsscal, stream, true /* = pointer_mode_host */,
507                         elem_count, &alpha, GpuComplex(GpuMemoryMutable(x)),
508                         incx);
509 }
510 
511 bool CUDABlas::DoBlasScal(Stream *stream, uint64_t elem_count, double alpha,
512                           DeviceMemory<std::complex<double>> *x, int incx) {
513   return DoBlasInternal(cublasZdscal, stream, true /* = pointer_mode_host */,
514                         elem_count, &alpha, GpuComplex(GpuMemoryMutable(x)),
515                         incx);
516 }
517 
518 bool CUDABlas::DoBlasScal(Stream *stream, uint64_t elem_count,
519                           std::complex<float> alpha,
520                           DeviceMemory<std::complex<float>> *x, int incx) {
521   auto cb_alpha = GpuComplexValue(alpha);
522   return DoBlasInternal(cublasCscal, stream, true /* = pointer_mode_host */,
523                         elem_count, GpuComplex(&cb_alpha),
524                         GpuComplex(GpuMemoryMutable(x)), incx);
525 }
526 
527 bool CUDABlas::DoBlasScal(Stream *stream, uint64_t elem_count,
528                           std::complex<double> alpha,
529                           DeviceMemory<std::complex<double>> *x, int incx) {
530   auto cb_alpha = GpuComplexValue(alpha);
531   return DoBlasInternal(cublasZscal, stream, true /* = pointer_mode_host */,
532                         elem_count, GpuComplex(&cb_alpha),
533                         GpuComplex(GpuMemoryMutable(x)), incx);
534 }
535 
536 bool CUDABlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m,
537                           uint64_t n, float alpha, const DeviceMemory<float> &a,
538                           int lda, const DeviceMemory<float> &x, int incx,
539                           float beta, DeviceMemory<float> *y, int incy) {
540   return DoBlasInternal(cublasSgemv, stream, true /* = pointer_mode_host */,
541                         AsCublasOperation(trans), m, n, &alpha, GpuMemory(a),
542                         lda, GpuMemory(x), incx, &beta, GpuMemoryMutable(y),
543                         incy);
544 }
545 
546 bool CUDABlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m,
547                           uint64_t n, double alpha,
548                           const DeviceMemory<double> &a, int lda,
549                           const DeviceMemory<double> &x, int incx, double beta,
550                           DeviceMemory<double> *y, int incy) {
551   return DoBlasInternal(cublasDgemv, stream, true /* = pointer_mode_host */,
552                         AsCublasOperation(trans), m, n, &alpha, GpuMemory(a),
553                         lda, GpuMemory(x), incx, &beta, GpuMemoryMutable(y),
554                         incy);
555 }
556 
557 bool CUDABlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m,
558                           uint64_t n, std::complex<float> alpha,
559                           const DeviceMemory<std::complex<float>> &a, int lda,
560                           const DeviceMemory<std::complex<float>> &x, int incx,
561                           std::complex<float> beta,
562                           DeviceMemory<std::complex<float>> *y, int incy) {
563   auto cb_alpha = GpuComplexValue(alpha);
564   auto cb_beta = GpuComplexValue(beta);
565   return DoBlasInternal(cublasCgemv, stream, true /* = pointer_mode_host */,
566                         AsCublasOperation(trans), m, n, GpuComplex(&cb_alpha),
567                         GpuComplex(GpuMemory(a)), lda, GpuComplex(GpuMemory(x)),
568                         incx, GpuComplex(&cb_beta),
569                         GpuComplex(GpuMemoryMutable(y)), incy);
570 }
571 
572 bool CUDABlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m,
573                           uint64_t n, std::complex<double> alpha,
574                           const DeviceMemory<std::complex<double>> &a, int lda,
575                           const DeviceMemory<std::complex<double>> &x, int incx,
576                           std::complex<double> beta,
577                           DeviceMemory<std::complex<double>> *y, int incy) {
578   auto cb_alpha = GpuComplexValue(alpha);
579   auto cb_beta = GpuComplexValue(beta);
580   return DoBlasInternal(cublasZgemv, stream, true /* = pointer_mode_host */,
581                         AsCublasOperation(trans), m, n, GpuComplex(&cb_alpha),
582                         GpuComplex(GpuMemory(a)), lda, GpuComplex(GpuMemory(x)),
583                         incx, GpuComplex(&cb_beta),
584                         GpuComplex(GpuMemoryMutable(y)), incy);
585 }
586 
587 bool CUDABlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64_t n,
588                           uint64_t k, float alpha, const DeviceMemory<float> &a,
589                           int lda, const DeviceMemory<float> &x, int incx,
590                           float beta, DeviceMemory<float> *y, int incy) {
591   return DoBlasInternal(cublasSsbmv, stream, true /* = pointer_mode_host */,
592                         CUDABlasUpperLower(uplo), n, k, &alpha, GpuMemory(a),
593                         lda, GpuMemory(x), incx, &beta, GpuMemoryMutable(y),
594                         incy);
595 }
596 
597 bool CUDABlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64_t n,
598                           uint64_t k, double alpha,
599                           const DeviceMemory<double> &a, int lda,
600                           const DeviceMemory<double> &x, int incx, double beta,
601                           DeviceMemory<double> *y, int incy) {
602   return DoBlasInternal(cublasDsbmv, stream, true /* = pointer_mode_host */,
603                         CUDABlasUpperLower(uplo), n, k, &alpha, GpuMemory(a),
604                         lda, GpuMemory(x), incx, &beta, GpuMemoryMutable(y),
605                         incy);
606 }
607 
608 port::Status CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
609                                   blas::Transpose transb, uint64_t m, uint64 n,
610                                   uint64_t k, blas::DataType dtype,
611                                   const void *alpha, const DeviceMemoryBase &a,
612                                   int lda, const DeviceMemoryBase &b, int ldb,
613                                   const void *beta, DeviceMemoryBase *c,
614                                   int ldc, blas::ComputePrecision precision) {
615   cublasMath_t math_type = CUBLAS_DEFAULT_MATH;
616 
617 #if CUDA_VERSION < 11000
618   if (dtype == blas::DataType::kHalf) {
619     math_type = CUBLAS_TENSOR_OP_MATH;
620   }
621 #else
622   if (dtype == blas::DataType::kFloat) {
623     math_type = CUBLAS_TF32_TENSOR_OP_MATH;
624     if (stream->GetCudaComputeCapability().IsAtLeast(
625             CudaComputeCapability::AMPERE)) {
626       // TODO(reedwm): Remove or make this VLOG(1) once TensorFloat-32 is more
627       // well tested.
628       if (tensorflow::tensor_float_32_execution_enabled()) {
629         LOG_FIRST_N(INFO, 1) << "TensorFloat-32 will be used for the matrix "
630                                 "multiplication. This will only be logged "
631                                 "once.";
632       }
633     }
634     if (precision > blas::kDefaultComputePrecision) {
635       math_type = CUBLAS_DEFAULT_MATH;
636     }
637   }
638 #endif
639 
640   // TODO(cheshire): Return an error instead.
641   // TODO(cheshire): Why are these checked only for `half` and `float`?
642   if (dtype == blas::DataType::kHalf || dtype == blas::DataType::kFloat) {
643     if (transa == blas::Transpose::kNoTranspose) {
644       if (lda < static_cast<int64_t>(m)) {
645         LOG(WARNING) << "GEMM lda was smaller than m (no transpose case); "
646                         "precondition violation";
647       }
648     } else {
649       if (lda < static_cast<int64_t>(k)) {
650         LOG(WARNING) << "GEMM lda (" << lda << ") was smaller than k (" << k
651                      << ") (transpose case); precondition violation";
652       }
653     }
654     if (transb == blas::Transpose::kNoTranspose) {
655       if (ldb < static_cast<int64_t>(k)) {
656         LOG(WARNING) << "GEMM ldb (" << ldb << ") was smaller than k (" << k
657                      << ") (no transpose case); precondition violation";
658       }
659     } else {
660       if (ldb < static_cast<int64_t>(n)) {
661         LOG(WARNING) << "GEMM ldb was smaller than n (transpose case); "
662                         "precondition violation";
663       }
664     }
665   }
666 
667   VLOG(1) << absl::StrFormat(
668       "doing cuBLAS SGEMM: at=%d bt=%d m=%u n=%u "
669       "k=%u alpha=%p a=%p lda=%d b=%p ldb=%d beta=%p "
670       "c=%p ldc=%d",
671       static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha,
672       a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc);
673 
674   switch (dtype) {
675     case blas::DataType::kHalf: {
676 #if CUDA_VERSION < 7050
677       return port::InternalError(
678           "fp16 sgemm is not implemented in this cuBLAS version "
679           "(need at least CUDA 7.5)");
680 #endif
681 
682       return DoBlasInternalImpl(
683           cublasSgemmEx, stream, true /* = pointer_mode_host */, math_type,
684           AsCublasOperation(transa), AsCublasOperation(transb), m, n, k,
685           static_cast<const float *>(alpha), a.opaque(), SE_CUDA_DATA_HALF, lda,
686           b.opaque(), SE_CUDA_DATA_HALF, ldb, static_cast<const float *>(beta),
687           c->opaque(), SE_CUDA_DATA_HALF, ldc);
688     }
689 #if CUDA_VERSION > 11000
690     case blas::DataType::kBF16: {
691       return DoBlasInternalImpl(
692           cublasSgemmEx, stream, true /* = pointer_mode_host */, math_type,
693           AsCublasOperation(transa), AsCublasOperation(transb), m, n, k,
694           static_cast<const float *>(alpha), a.opaque(), CUDA_R_16BF, lda,
695           b.opaque(), CUDA_R_16BF, ldb, static_cast<const float *>(beta),
696           c->opaque(), CUDA_R_16BF, ldc);
697     }
698 #endif
699     case dnn::kFloat:
700       return DoBlasInternalImpl(
701           cublasSgemm, stream, true /* = pointer_mode_host */, math_type,
702           AsCublasOperation(transa), AsCublasOperation(transb), m, n, k,
703           static_cast<const float *>(alpha),
704           static_cast<const float *>(a.opaque()), lda,
705           static_cast<const float *>(b.opaque()), ldb,
706           static_cast<const float *>(beta), static_cast<float *>(c->opaque()),
707           ldc);
708     case dnn::kDouble:
709       return DoBlasInternalImpl(
710           cublasDgemm, stream, true /* = pointer_mode_host */, math_type,
711           AsCublasOperation(transa), AsCublasOperation(transb), m, n, k,
712           static_cast<const double *>(alpha),
713           static_cast<const double *>(a.opaque()), lda,
714           static_cast<const double *>(b.opaque()), ldb,
715           static_cast<const double *>(beta), static_cast<double *>(c->opaque()),
716           ldc);
717     case dnn::kComplexFloat: {
718       GpuComplexType cb_alpha =
719           GpuComplexValue(*static_cast<const std::complex<float> *>(alpha));
720       GpuComplexType cb_beta =
721           GpuComplexValue(*static_cast<const std::complex<float> *>(beta));
722       return DoBlasInternalImpl(
723           cublasCgemm, stream, true /* = pointer_mode_host */, math_type,
724           AsCublasOperation(transa), AsCublasOperation(transb), m, n, k,
725           &cb_alpha, static_cast<const GpuComplexType *>(a.opaque()), lda,
726           static_cast<const GpuComplexType *>(b.opaque()), ldb, &cb_beta,
727           static_cast<GpuComplexType *>(c->opaque()), ldc);
728     }
729     case dnn::kComplexDouble: {
730       GpuDoubleComplexType cb_alpha =
731           GpuComplexValue(*static_cast<const std::complex<double> *>(alpha));
732       GpuDoubleComplexType cb_beta =
733           GpuComplexValue(*static_cast<const std::complex<double> *>(beta));
734       return DoBlasInternalImpl(
735           cublasZgemm, stream, true /* = pointer_mode_host */, math_type,
736           AsCublasOperation(transa), AsCublasOperation(transb), m, n, k,
737           &cb_alpha, static_cast<const GpuDoubleComplexType *>(a.opaque()), lda,
738           static_cast<const GpuDoubleComplexType *>(b.opaque()), ldb, &cb_beta,
739           static_cast<GpuDoubleComplexType *>(c->opaque()), ldc);
740     }
741     default:
742       return port::InternalError(absl::StrCat("Unsupported datatype for GEMM: ",
743                                               blas::DataTypeString(dtype)));
744   }
745 }
746 
747 bool CUDABlas::DoBlasGemvWithProfiling(
748     Stream *stream, blas::Transpose trans, uint64_t m, uint64 n, float alpha,
749     const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,
750     int incx, float beta, DeviceMemory<float> *y, int incy,
751     blas::ProfileResult *output_profile_result) {
752   return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
753                                      incx, beta, y, incy,
754                                      output_profile_result);
755 }
756 
757 bool CUDABlas::DoBlasGemvWithProfiling(
758     Stream *stream, blas::Transpose trans, uint64_t m, uint64 n, double alpha,
759     const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,
760     int incx, double beta, DeviceMemory<double> *y, int incy,
761     blas::ProfileResult *output_profile_result) {
762   return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
763                                      incx, beta, y, incy,
764                                      output_profile_result);
765 }
766 
767 bool CUDABlas::DoBlasGemvWithProfiling(
768     Stream *stream, blas::Transpose trans, uint64_t m, uint64 n,
769     std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a,
770     int lda, const DeviceMemory<std::complex<float>> &x, int incx,
771     std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
772     blas::ProfileResult *output_profile_result) {
773   return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
774                                      incx, beta, y, incy,
775                                      output_profile_result);
776 }
777 
778 bool CUDABlas::DoBlasGemvWithProfiling(
779     Stream *stream, blas::Transpose trans, uint64_t m, uint64 n,
780     std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a,
781     int lda, const DeviceMemory<std::complex<double>> &x, int incx,
782     std::complex<double> beta, DeviceMemory<std::complex<double>> *y, int incy,
783     blas::ProfileResult *output_profile_result) {
784   return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
785                                      incx, beta, y, incy,
786                                      output_profile_result);
787 }
788 
789 bool CUDABlas::DoBlasGemmWithProfiling(
790     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
791     uint64_t n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
792     int lda, const DeviceMemory<Eigen::half> &b, int ldb, float beta,
793     DeviceMemory<Eigen::half> *c, int ldc,
794     blas::ProfileResult *output_profile_result) {
795   return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
796                                      lda, b, ldb, beta, c, ldc,
797                                      output_profile_result);
798 }
799 
800 bool CUDABlas::DoBlasGemmWithProfiling(
801     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
802     uint64_t n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
803     const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
804     int ldc, blas::ProfileResult *output_profile_result) {
805   return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
806                                      lda, b, ldb, beta, c, ldc,
807                                      output_profile_result);
808 }
809 
810 bool CUDABlas::DoBlasGemmWithProfiling(
811     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
812     uint64_t n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
813     const DeviceMemory<double> &b, int ldb, double beta,
814     DeviceMemory<double> *c, int ldc,
815     blas::ProfileResult *output_profile_result) {
816   return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
817                                      lda, b, ldb, beta, c, ldc,
818                                      output_profile_result);
819 }
820 
821 bool CUDABlas::DoBlasGemmWithProfiling(
822     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
823     uint64_t n, uint64 k, std::complex<float> alpha,
824     const DeviceMemory<std::complex<float>> &a, int lda,
825     const DeviceMemory<std::complex<float>> &b, int ldb,
826     std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
827     blas::ProfileResult *output_profile_result) {
828   return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
829                                      lda, b, ldb, beta, c, ldc,
830                                      output_profile_result);
831 }
832 
833 bool CUDABlas::DoBlasGemmWithProfiling(
834     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
835     uint64_t n, uint64 k, std::complex<double> alpha,
836     const DeviceMemory<std::complex<double>> &a, int lda,
837     const DeviceMemory<std::complex<double>> &b, int ldb,
838     std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
839     blas::ProfileResult *output_profile_result) {
840   return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
841                                      lda, b, ldb, beta, c, ldc,
842                                      output_profile_result);
843 }
844 
845 template <typename T>
846 bool CUDABlas::DoBlasGemvWithProfilingImpl(
847     Stream *stream, blas::Transpose trans, uint64_t m, uint64 n, const T &alpha,
848     const DeviceMemory<T> &a, int lda, const DeviceMemory<T> &x, int incx,
849     const T &beta, DeviceMemory<T> *y, int incy,
850     blas::ProfileResult *output_profile_result) {
851   std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
852   if (output_profile_result != nullptr) {
853     timer.reset(new GpuTimer(parent_));
854     if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
855       return false;
856     }
857   }
858 
859   // Call blasGemv
860   bool result =
861       DoBlasGemv(stream, trans, m, n, alpha, a, lda, x, incx, beta, y, incy);
862 
863   if (timer != nullptr && result) {
864     // GpuTimer will CHECK-fail if we Stop() it while the stream is in an error
865     // state.
866     if (!timer->Stop(AsGpuStream(stream))) {
867       return false;
868     }
869     output_profile_result->set_is_valid(true);
870     output_profile_result->set_algorithm(blas::kDefaultBlasGemv);
871     output_profile_result->set_elapsed_time_in_ms(
872         timer->GetElapsedMilliseconds());
873   }
874   return result;
875 }
876 
877 template <typename T, typename ParamType>
878 bool CUDABlas::DoBlasGemmWithProfilingImpl(
879     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
880     uint64_t n, uint64 k, const ParamType &alpha, const DeviceMemory<T> &a,
881     int lda, const DeviceMemory<T> &b, int ldb, const ParamType &beta,
882     DeviceMemory<T> *c, int ldc, blas::ProfileResult *output_profile_result) {
883   std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
884   if (output_profile_result != nullptr) {
885     timer.reset(new GpuTimer(parent_));
886     if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
887       return false;
888     }
889   }
890 
891   // Call blasGemm
892   bool result = DoBlasGemm(stream, transa, transb, m, n, k,
893                            blas::ToDataType<T>::value, &alpha, a, lda, b, ldb,
894                            &beta, c, ldc, blas::kDefaultComputePrecision)
895                     .ok();
896 
897   if (timer != nullptr && result) {
898     // GpuTimer will CHECK-fail if we Stop() it while the stream is in an error
899     // state.
900     if (!timer->Stop(AsGpuStream(stream))) {
901       return false;
902     }
903     output_profile_result->set_is_valid(true);
904     output_profile_result->set_algorithm(blas::kDefaultBlasGemm);
905     output_profile_result->set_elapsed_time_in_ms(
906         timer->GetElapsedMilliseconds());
907   }
908   return result;
909 }
910 
911 static bool UsesTensorOps(blas::AlgorithmType algo) {
912 #if CUDA_VERSION >= 9000
913   cublasGemmAlgo_t cublas_algo = static_cast<cublasGemmAlgo_t>(algo);
914   return cublas_algo >= CUBLAS_GEMM_DEFAULT_TENSOR_OP;
915 #else
916   return false;
917 #endif
918 }
919 
920 static port::StatusOr<cublasMath_t> GetMathTypeForGemmEx(
921     Stream *stream, blas::AlgorithmType algorithm, blas::DataType type_a,
922     blas::DataType type_b) {
923   if (type_a != type_b) {
924     return port::InternalError("Types of inputs mismatch");
925   }
926 
927   // GPUs < sm_50 don't support cublasGemmEx.
928   CudaComputeCapability cc = stream->GetCudaComputeCapability();
929   if (cc.major < 5) {
930     return port::InternalError(absl::StrCat(
931         "sm_", cc.major, " does not support explicit gemm algorithms."));
932   }
933 
934   bool algo_uses_tensor_ops = UsesTensorOps(algorithm);
935   cublasMath_t math_type = CUBLAS_DEFAULT_MATH;
936   if (algo_uses_tensor_ops) {
937     if (cc.major < 7) {
938       return port::InternalError(absl::StrCat(
939           "Algorithm ", algorithm,
940           " uses tensor ops, but tensor ops are not available in sm", cc.major,
941           "X devices."));
942     } else if (type_a == blas::DataType::kFloat) {
943 #if CUDA_VERSION < 11000
944       return port::InternalError(absl::StrCat(
945           "Algorithm ", algorithm,
946           " uses tensor ops, but tensor ops are not available for fp32"));
947 #else
948       if (cc.major < 8) {
949         return port::InternalError(absl::StrCat(
950             "Algorithm ", algorithm,
951             " uses tensor ops, but tensor ops are not available in sm",
952             cc.major, "X devices for float input types."));
953       } else if (!tensorflow::tensor_float_32_execution_enabled()) {
954         return port::InternalError(absl::StrCat(
955             "Algorithm ", algorithm,
956             " uses tensor ops, but tensor ops are disabled for fp32 inputs"));
957       }
958       math_type = CUBLAS_TF32_TENSOR_OP_MATH;
959 #endif
960     } else if (type_a == blas::DataType::kHalf) {
961 #if CUDA_VERSION < 11000
962       math_type = CUBLAS_TENSOR_OP_MATH;
963 #endif
964     } else {
965       return port::InternalError(
966           absl::StrCat("Algorithm ", algorithm,
967                        " uses tensor ops which are not supported for input"));
968     }
969   }
970 
971   // Return an error if we might be hitting a cuBLAS bug that produces the wrong
972   // result. See nvbugs/2156201, b/79126339.
973 #if CUDA_VERSION >= 9000 && CUDA_VERSION < 9020
974   if ((algorithm == CUBLAS_GEMM_DEFAULT || algorithm >= CUBLAS_GEMM_ALGO13) &&
975       std::max({m, n, k}) >= 2097153 && cc_major < 7) {
976     return port::InternalError(
977         "DoBlasGemmWithAlgorithm returning false to work around cudnn "
978         "<9.2 bug with m, n, or k >= 2097153.  See b/79126339.");
979   }
980 #endif
981   return math_type;
982 }
983 
984 static port::StatusOr<std::unique_ptr<GpuTimer, GpuTimerDeleter>>
985 StartGpuTimerForProfile(Stream *stream, GpuExecutor *executor,
986                         blas::ProfileResult *output_profile_result) {
987   std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
988   if (output_profile_result) {
989     timer.reset(new GpuTimer(executor));
990     if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
991       return port::InternalError(
992           "output_profile_result given, but unable to create a GpuTimer");
993     }
994   }
995   return timer;
996 }
997 
998 static port::Status PopulateProfileFromTimer(
999     GpuTimer *timer, blas::AlgorithmType algorithm,
1000     blas::ProfileResult *output_profile_result, Stream *stream) {
1001   if (timer) {
1002     // GpuTimer will CHECK-fail if we Stop() it while the stream is in an error
1003     // state.
1004     if (!timer->Stop(AsGpuStream(stream))) {
1005       return port::InternalError("unable to stop GpuTimer.");
1006     }
1007     output_profile_result->set_is_valid(true);
1008     output_profile_result->set_algorithm(algorithm);
1009     output_profile_result->set_elapsed_time_in_ms(
1010         timer->GetElapsedMilliseconds());
1011   }
1012   return ::tensorflow::OkStatus();
1013 }
1014 
1015 port::Status CUDABlas::DoBlasGemmWithAlgorithm(
1016     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
1017     uint64_t n, uint64 k, const void *alpha, const DeviceMemoryBase &a,
1018     blas::DataType type_a, int lda, const DeviceMemoryBase &b,
1019     blas::DataType type_b, int ldb, const void *beta, DeviceMemoryBase *c,
1020     blas::DataType type_c, int ldc, blas::ComputationType computation_type,
1021     blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
1022   TF_ASSIGN_OR_RETURN(cublasMath_t math_type,
1023                       GetMathTypeForGemmEx(stream, algorithm, type_a, type_b));
1024 
1025   TF_ASSIGN_OR_RETURN(auto timer, StartGpuTimerForProfile(
1026                                       stream, parent_, output_profile_result));
1027 
1028   // Since we are converting 'algorithm' to cublasGemmAlgo_t by static_cast,
1029   // we do the following compile-time check on the default value:
1030   static_assert(blas::kDefaultGemmAlgo == CUBLAS_GEMM_DFALT, "");
1031 
1032   TF_RETURN_IF_ERROR(DoBlasInternalImpl(
1033       AS_LAMBDA(cublasGemmEx), stream, /*pointer_mode_host=*/true, math_type,
1034       AsCublasOperation(transa), AsCublasOperation(transb), m, n, k, alpha,
1035       a.opaque(), AsCudaDataType(type_a), lda, b.opaque(),
1036       AsCudaDataType(type_b), ldb, beta, c->opaque(), AsCudaDataType(type_c),
1037       ldc, AsCublasComputeType(computation_type),
1038       static_cast<cublasGemmAlgo_t>(algorithm)));
1039   TF_RETURN_IF_ERROR(PopulateProfileFromTimer(timer.get(), algorithm,
1040                                               output_profile_result, stream));
1041   return ::tensorflow::OkStatus();
1042 }
1043 
1044 port::Status CUDABlas::DoBlasGemmStridedBatchedWithAlgorithm(
1045     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
1046     uint64_t n, uint64 k, const void *alpha, const DeviceMemoryBase &a,
1047     blas::DataType type_a, int lda, int64_t stride_a, const DeviceMemoryBase &b,
1048     blas::DataType type_b, int ldb, int64_t stride_b, const void *beta,
1049     DeviceMemoryBase *c, blas::DataType type_c, int ldc, int64_t stride_c,
1050     int batch_count, blas::ComputationType computation_type,
1051     blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
1052   TF_ASSIGN_OR_RETURN(cublasMath_t math_type,
1053                       GetMathTypeForGemmEx(stream, algorithm, type_a, type_b));
1054   TF_ASSIGN_OR_RETURN(auto timer, StartGpuTimerForProfile(
1055                                       stream, parent_, output_profile_result));
1056 
1057   cudaDataType_t cuda_in_type = AsCudaDataType(type_a);
1058 
1059 #if CUDA_VERSION >= 11000
1060   // Work around a CUDA bug where bf16 batched GEMM is erroneously marked as
1061   // unsupported on Pascal by manually unbatching it.
1062   if (cuda_in_type == CUDA_R_16BF &&
1063       !stream->GetCudaComputeCapability().IsAtLeast(7)) {
1064     for (int batch = 0; batch < batch_count; ++batch) {
1065       const auto *a_matrix = reinterpret_cast<const __nv_bfloat16 *>(
1066           static_cast<const Eigen::bfloat16 *>(a.opaque()) + batch * stride_a);
1067       const auto *b_matrix = reinterpret_cast<const __nv_bfloat16 *>(
1068           static_cast<const Eigen::bfloat16 *>(b.opaque()) + batch * stride_b);
1069       auto *c_matrix = reinterpret_cast<__nv_bfloat16 *>(
1070           static_cast<Eigen::bfloat16 *>(c->opaque()) + batch * stride_c);
1071       TF_RETURN_IF_ERROR(DoBlasInternalImpl(
1072           AS_LAMBDA(cublasGemmEx), stream, /*pointer_mode_host=*/true,
1073           math_type, AsCublasOperation(transa), AsCublasOperation(transb), m, n,
1074           k, static_cast<const float *>(alpha), a_matrix, CUDA_R_16BF, lda,
1075           b_matrix, CUDA_R_16BF, ldb, static_cast<const float *>(beta),
1076           c_matrix, CUDA_R_16BF, ldc, AsCublasComputeType(computation_type),
1077           static_cast<cublasGemmAlgo_t>(algorithm)));
1078     }
1079     TF_RETURN_IF_ERROR(PopulateProfileFromTimer(timer.get(), algorithm,
1080                                                 output_profile_result, stream));
1081     return port::Status::OK();
1082   }
1083 #endif
1084 
1085   TF_RETURN_IF_ERROR(DoBlasInternalImpl(
1086       AS_LAMBDA(cublasGemmStridedBatchedEx), stream, /*pointer_mode_host=*/true,
1087       math_type, AsCublasOperation(transa), AsCublasOperation(transb), m, n, k,
1088       alpha, a.opaque(), cuda_in_type, lda, stride_a, b.opaque(), cuda_in_type,
1089       ldb, stride_b, beta, c->opaque(), AsCudaDataType(type_c), ldc, stride_c,
1090       batch_count, AsCublasComputeType(computation_type),
1091       static_cast<cublasGemmAlgo_t>(algorithm)));
1092   TF_RETURN_IF_ERROR(PopulateProfileFromTimer(timer.get(), algorithm,
1093                                               output_profile_result, stream));
1094   return ::tensorflow::OkStatus();
1095 }
1096 
1097 bool CUDABlas::GetBlasGemmAlgorithms(
1098     Stream *stream, std::vector<blas::AlgorithmType> *out_algorithms) {
1099   // cublasGemmAlgo_t (and the function that accepts this type, cublasGemmEx)
1100   // were first introduced in CUDA 8.
1101   //
1102   // Note that when the CUDA version and compute capability are not
1103   // sufficient, we still return out_algorithms. The caller needs to make sure
1104   // that, in this case, the returned vector is empty.
1105   if (stream->GetCudaComputeCapability().IsAtLeast(
1106           CudaComputeCapability::AMPERE)) {
1107     // Note: for NVIDIA Ampere Architecture GPUs and beyond, i.e. SM version >=
1108     // 80, the numbered algorithm options are equivalent to CUBLAS_GEMM_DEFAULT
1109     // or CUBLAS_GEMM_DEFAULT_TENSOR_OP respectively.
1110     *out_algorithms = {
1111         CUBLAS_GEMM_DFALT,
1112         CUBLAS_GEMM_DFALT_TENSOR_OP,
1113     };
1114   } else {
1115     *out_algorithms = {
1116       CUBLAS_GEMM_DFALT,
1117       CUBLAS_GEMM_ALGO0,
1118       CUBLAS_GEMM_ALGO1,
1119       CUBLAS_GEMM_ALGO2,
1120       CUBLAS_GEMM_ALGO3,
1121       CUBLAS_GEMM_ALGO4,
1122       CUBLAS_GEMM_ALGO5,
1123       CUBLAS_GEMM_ALGO6,
1124       CUBLAS_GEMM_ALGO7,
1125 #if CUDA_VERSION >= 9000
1126       CUBLAS_GEMM_ALGO8,
1127       CUBLAS_GEMM_ALGO9,
1128       CUBLAS_GEMM_ALGO10,
1129       CUBLAS_GEMM_ALGO11,
1130       CUBLAS_GEMM_ALGO12,
1131       CUBLAS_GEMM_ALGO13,
1132       CUBLAS_GEMM_ALGO14,
1133       CUBLAS_GEMM_ALGO15,
1134       CUBLAS_GEMM_ALGO16,
1135       CUBLAS_GEMM_ALGO17,
1136       CUBLAS_GEMM_DFALT_TENSOR_OP,
1137       CUBLAS_GEMM_ALGO0_TENSOR_OP,
1138       CUBLAS_GEMM_ALGO1_TENSOR_OP,
1139       CUBLAS_GEMM_ALGO2_TENSOR_OP,
1140       CUBLAS_GEMM_ALGO3_TENSOR_OP,
1141       CUBLAS_GEMM_ALGO4_TENSOR_OP,
1142 #endif
1143 #if CUDA_VERSION >= 9020
1144       CUBLAS_GEMM_ALGO18,
1145       CUBLAS_GEMM_ALGO19,
1146       CUBLAS_GEMM_ALGO20,
1147       CUBLAS_GEMM_ALGO21,
1148       CUBLAS_GEMM_ALGO22,
1149       CUBLAS_GEMM_ALGO23,
1150       CUBLAS_GEMM_ALGO5_TENSOR_OP,
1151       CUBLAS_GEMM_ALGO6_TENSOR_OP,
1152       CUBLAS_GEMM_ALGO7_TENSOR_OP,
1153       CUBLAS_GEMM_ALGO8_TENSOR_OP,
1154       CUBLAS_GEMM_ALGO9_TENSOR_OP,
1155       CUBLAS_GEMM_ALGO10_TENSOR_OP,
1156       CUBLAS_GEMM_ALGO11_TENSOR_OP,
1157       CUBLAS_GEMM_ALGO12_TENSOR_OP,
1158       CUBLAS_GEMM_ALGO13_TENSOR_OP,
1159       CUBLAS_GEMM_ALGO14_TENSOR_OP,
1160       CUBLAS_GEMM_ALGO15_TENSOR_OP,
1161 #endif
1162     };
1163   }
1164   return true;
1165 }
1166 
1167 template <typename T>
1168 struct HalfAsFloat {
1169   typedef T type;
1170 };
1171 
1172 template <>
1173 struct HalfAsFloat<Eigen::half> {
1174   typedef float type;
1175 };
1176 
1177 namespace {
1178 // Pass-through for non-complex types that don't need conversion to a
1179 // cuBLAS-specific type.
1180 template <typename T>
1181 T inline GpuComplexValue(T v) {
1182   return v;
1183 }
1184 }  // namespace
1185 
1186 template <typename T, typename Scalar, typename FuncT>
1187 port::Status CUDABlas::DoBlasGemmBatchedInternal(
1188     FuncT cublas_func, Stream *stream, blas::Transpose transa,
1189     blas::Transpose transb, uint64_t m, uint64 n, uint64 k, Scalar alpha,
1190     const DeviceMemorySlice<T> &a_ptrs_to_wrappers, int lda,
1191     const DeviceMemorySlice<T> &b_ptrs_to_wrappers, int ldb, Scalar beta,
1192     const DeviceMemorySlice<T> &c_ptrs_to_wrappers, int ldc, int batch_count,
1193     ScratchAllocator *scratch_allocator) {
1194   std::vector<T *> a_raw_ptrs, b_raw_ptrs, c_raw_ptrs;
1195   for (int i = 0; i < batch_count; ++i) {
1196     a_raw_ptrs.push_back(static_cast<T *>(a_ptrs_to_wrappers[i]->opaque()));
1197     b_raw_ptrs.push_back(static_cast<T *>(b_ptrs_to_wrappers[i]->opaque()));
1198     c_raw_ptrs.push_back(static_cast<T *>(c_ptrs_to_wrappers[i]->opaque()));
1199   }
1200 
1201   typedef typename HalfAsFloat<typename GpuComplexT<T>::type>::type CUDA_T;
1202 
1203   const size_t size = batch_count * sizeof(CUDA_T *);
1204 
1205   // Device-side copy of pointers to matrices.
1206   DeviceMemory<CUDA_T *> a;
1207   DeviceMemory<CUDA_T *> b;
1208   DeviceMemory<CUDA_T *> c;
1209 
1210   // If temporary space is allocated for device-side copies of pointers to
1211   // matrices, that temporary space should not be freed until this function
1212   // returns. Although the values for these unique_ptrs are not set here, they
1213   // are declared at this scope so they will be destroyed when the function
1214   // returns.
1215   //
1216   // If a scratch allocator is provided, these pointers will not be used at all.
1217   std::unique_ptr<TemporaryDeviceMemory<CUDA_T *>> a_temporary;
1218   std::unique_ptr<TemporaryDeviceMemory<CUDA_T *>> b_temporary;
1219   std::unique_ptr<TemporaryDeviceMemory<CUDA_T *>> c_temporary;
1220 
1221   // Decide how to allocate device-side copy of pointers to matrices based on
1222   // whether a scratch allocator was passed.
1223   if (scratch_allocator != nullptr) {
1224     TF_ASSIGN_OR_RETURN(DeviceMemory<uint8> a_bytes,
1225                         scratch_allocator->AllocateBytes(size));
1226     TF_ASSIGN_OR_RETURN(DeviceMemory<uint8> b_bytes,
1227                         scratch_allocator->AllocateBytes(size));
1228     TF_ASSIGN_OR_RETURN(DeviceMemory<uint8> c_bytes,
1229                         scratch_allocator->AllocateBytes(size));
1230     a = DeviceMemory<CUDA_T *>(a_bytes);
1231     b = DeviceMemory<CUDA_T *>(b_bytes);
1232     c = DeviceMemory<CUDA_T *>(c_bytes);
1233   } else {
1234     TF_ASSIGN_OR_RETURN(a_temporary,
1235                         stream->AllocateTemporaryArray<CUDA_T *>(batch_count));
1236     TF_ASSIGN_OR_RETURN(b_temporary,
1237                         stream->AllocateTemporaryArray<CUDA_T *>(batch_count));
1238     TF_ASSIGN_OR_RETURN(c_temporary,
1239                         stream->AllocateTemporaryArray<CUDA_T *>(batch_count));
1240     a = DeviceMemory<CUDA_T *>(*a_temporary->mutable_device_memory());
1241     b = DeviceMemory<CUDA_T *>(*b_temporary->mutable_device_memory());
1242     c = DeviceMemory<CUDA_T *>(*c_temporary->mutable_device_memory());
1243   }
1244 
1245   if (!stream->ThenMemcpy(&a, a_raw_ptrs.data(), size).ok() ||
1246       !stream->ThenMemcpy(&b, b_raw_ptrs.data(), size).ok() ||
1247       !stream->ThenMemcpy(&c, c_raw_ptrs.data(), size).ok()) {
1248     return port::Status(port::error::INTERNAL,
1249                         "failed to copy memory from host to device in "
1250                         "CUDABlas::DoBlasGemmBatched");
1251   }
1252 
1253   cudaDataType_t data_type = CUDADataType<T>::type;
1254 
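  // Prefer cublasGemmBatchedEx when it is available (CUDA >= 9.1) and the
  // device is SM 5.0 or newer: it covers all supported element types through a
  // single entry point, accumulates fp16 in fp32, and lets us request
  // tensor-op algorithms explicitly.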
1255 #if CUDA_VERSION >= 9010
1256   if (stream->GetCudaComputeCapability().IsAtLeast(5)) {
1257     cublasMath_t math_type;
1258     cublasGemmAlgo_t algo;
1259     if (data_type == CUDA_R_16F) {
1260 #if CUDA_VERSION < 11000
1261       math_type = CUBLAS_TENSOR_OP_MATH;
1262 #else
1263       math_type = CUBLAS_DEFAULT_MATH;
1264 #endif
1265       algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
1266 #if CUBLAS_VER_MAJOR >= 11
1267     } else if (data_type == CUDA_R_32F) {
1268       // DoBlasInternalImpl will switch math_type back to CUBLAS_DEFAULT_MATH
1269       // if TensorFloat-32 is disabled.
1270       math_type = CUBLAS_TF32_TENSOR_OP_MATH;
1271       algo = tensorflow::tensor_float_32_execution_enabled()
1272                  ? CUBLAS_GEMM_DFALT_TENSOR_OP
1273                  : CUBLAS_GEMM_DFALT;
1274 #endif
1275     } else {
1276       math_type = CUBLAS_DEFAULT_MATH;
1277       algo = CUBLAS_GEMM_DFALT;
1278     }
1279     cudaDataType_t compute_type =
1280         (data_type == CUDA_R_16F ? CUDA_R_32F : data_type);
1281     const void **a_void_ptrs = reinterpret_cast<const void **>(
1282         const_cast<const CUDA_T **>(GpuMemory(a)));
1283     const void **b_void_ptrs = reinterpret_cast<const void **>(
1284         const_cast<const CUDA_T **>(GpuMemory(b)));
1285     void **c_void_ptrs =
1286         reinterpret_cast<void **>(const_cast<CUDA_T **>(GpuMemory(c)));
1287     return DoBlasInternalImpl(
1288         AS_LAMBDA(cublasGemmBatchedEx), stream, true /* = pointer_mode_host */,
1289         math_type, AsCublasOperation(transa), AsCublasOperation(transb), m, n,
1290         k, &alpha, a_void_ptrs, data_type, lda, b_void_ptrs, data_type, ldb,
1291         &beta, c_void_ptrs, data_type, ldc, batch_count, compute_type, algo);
1292   }
1293 #endif
1294   // Either CUDA_VERSION < 9.1 or SM < 5.0: use the legacy batched API.
1295   if (data_type != CUDA_R_16F) {
1296     auto cb_alpha = GpuComplexValue(alpha);
1297     auto cb_beta = GpuComplexValue(beta);
1298     bool ok = DoBlasInternal(
1299         cublas_func, stream, true /* = pointer_mode_host */,
1300         AsCublasOperation(transa), AsCublasOperation(transb), m, n, k,
1301         GpuComplex(&cb_alpha), const_cast<const CUDA_T **>(GpuMemory(a)), lda,
1302         const_cast<const CUDA_T **>(GpuMemory(b)), ldb, GpuComplex(&cb_beta),
1303         const_cast<CUDA_T **>(GpuMemory(c)), ldc, batch_count);
1304     if (ok) {
1305       return ::tensorflow::OkStatus();
1306     }
1307     return port::Status(port::error::INTERNAL,
1308                         "failed BLAS call, see log for details");
1309   } else {
1310     // Fall back to a per-batch GEMM loop for fp16.
1311     for (int b = 0; b < batch_count; ++b) {
1312       const DeviceMemory<T> &a_matrix = *a_ptrs_to_wrappers[b];
1313       const DeviceMemory<T> &b_matrix = *b_ptrs_to_wrappers[b];
1314       DeviceMemory<T> *c_matrix = c_ptrs_to_wrappers[b];
1315       TF_RETURN_IF_ERROR(DoBlasGemm(
1316           stream, transa, transb, m, n, k, blas::ToDataType<T>::value, &alpha,
1317           a_matrix, lda, b_matrix, ldb, &beta, c_matrix, ldc,
1318           blas::kDefaultComputePrecision));
1319     }
1320     return ::tensorflow::OkStatus();
1321   }
1322 }
1323 
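// The bool-returning DoBlasGemmBatched overloads below adapt
// DoBlasGemmBatchedInternal to the blas::BlasSupport interface: a failed
// status is logged and surfaced as `false`.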
1324 bool CUDABlas::DoBlasGemmBatched(
1325     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
1326     uint64_t n, uint64 k, float alpha,
1327     const DeviceMemorySlice<Eigen::half> &a_array, int lda,
1328     const DeviceMemorySlice<Eigen::half> &b_array, int ldb, float beta,
1329     const DeviceMemorySlice<Eigen::half> &c_array, int ldc, int batch_count,
1330     ScratchAllocator *scratch_allocator) {
1331   // Note: The func passed here (cublasSgemmBatched) is not actually called,
1332   // due to special handling of fp16 inside DoBlasGemmBatchedInternal.
1333   port::Status status = DoBlasGemmBatchedInternal(
1334       cublasSgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda,
1335       b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
1336   if (!status.ok()) {
1337     LOG(ERROR) << status;
1338   }
1339   return status.ok();
1340 }
1341 
1342 bool CUDABlas::DoBlasGemmBatched(
1343     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
1344     uint64_t n, uint64 k, float alpha, const DeviceMemorySlice<float> &a_array,
1345     int lda, const DeviceMemorySlice<float> &b_array, int ldb, float beta,
1346     const DeviceMemorySlice<float> &c_array, int ldc, int batch_count,
1347     ScratchAllocator *scratch_allocator) {
1348   port::Status status = DoBlasGemmBatchedInternal(
1349       cublasSgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda,
1350       b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
1351   if (!status.ok()) {
1352     LOG(ERROR) << status;
1353   }
1354   return status.ok();
1355 }
1356 
1357 bool CUDABlas::DoBlasGemmBatched(
1358     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
1359     uint64_t n, uint64 k, double alpha,
1360     const DeviceMemorySlice<double> &a_array, int lda,
1361     const DeviceMemorySlice<double> &b_array, int ldb, double beta,
1362     const DeviceMemorySlice<double> &c_array, int ldc, int batch_count,
1363     ScratchAllocator *scratch_allocator) {
1364   port::Status status = DoBlasGemmBatchedInternal(
1365       cublasDgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda,
1366       b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
1367   if (!status.ok()) {
1368     LOG(ERROR) << status;
1369   }
1370   return status.ok();
1371 }
1372 
1373 bool CUDABlas::DoBlasGemmBatched(
1374     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
1375     uint64_t n, uint64 k, std::complex<float> alpha,
1376     const DeviceMemorySlice<std::complex<float>> &a_array, int lda,
1377     const DeviceMemorySlice<std::complex<float>> &b_array, int ldb,
1378     std::complex<float> beta,
1379     const DeviceMemorySlice<std::complex<float>> &c_array, int ldc,
1380     int batch_count, ScratchAllocator *scratch_allocator) {
1381   port::Status status = DoBlasGemmBatchedInternal(
1382       cublasCgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda,
1383       b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
1384   if (!status.ok()) {
1385     LOG(ERROR) << status;
1386   }
1387   return status.ok();
1388 }
1389 
1390 bool CUDABlas::DoBlasGemmBatched(
1391     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
1392     uint64_t n, uint64 k, std::complex<double> alpha,
1393     const DeviceMemorySlice<std::complex<double>> &a_array, int lda,
1394     const DeviceMemorySlice<std::complex<double>> &b_array, int ldb,
1395     std::complex<double> beta,
1396     const DeviceMemorySlice<std::complex<double>> &c_array, int ldc,
1397     int batch_count, ScratchAllocator *scratch_allocator) {
1398   port::Status status = DoBlasGemmBatchedInternal(
1399       cublasZgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda,
1400       b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
1401   if (!status.ok()) {
1402     LOG(ERROR) << status;
1403   }
1404   return status.ok();
1405 }
1406 
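// Strided batched GEMM: batch element i uses the operands at
//   a + i * stride_a,  b + i * stride_b,  c + i * stride_c
// with the strides measured in elements of `dtype` (this is how the fallback
// loops below compute their per-batch pointers). The switch dispatches to the
// cuBLAS routine for each data type and falls back to per-batch GEMMs when no
// suitable strided kernel is available.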
1407 port::Status CUDABlas::DoBlasGemmStridedBatched(
1408     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
1409     uint64_t n, uint64 k, blas::DataType dtype, const void *alpha,
1410     const DeviceMemoryBase &a, int lda, int64_t stride_a,
1411     const DeviceMemoryBase &b, int ldb, int64_t stride_b, const void *beta,
1412     DeviceMemoryBase *c, int ldc, int64_t stride_c, int batch_count) {
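  // Math-mode selection: before CUDA 11, fp16 tensor cores have to be
  // requested explicitly via CUBLAS_TENSOR_OP_MATH; from CUDA 11 on they are
  // used for fp16 by default, and CUBLAS_TF32_TENSOR_OP_MATH opts fp32 GEMMs
  // into TF32 (DoBlasInternalImpl reverts this to CUBLAS_DEFAULT_MATH when
  // TF32 execution is disabled).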
1413   cublasMath_t math_type = CUBLAS_DEFAULT_MATH;
1414 #if CUDA_VERSION < 11000
1415   if (dtype == dnn::kHalf) {
1416     math_type = CUBLAS_TENSOR_OP_MATH;
1417   }
1418 #else
1419   if (dtype == dnn::kFloat) {
1420     math_type = CUBLAS_TF32_TENSOR_OP_MATH;
1421   }
1422 #endif
1423 
1424   switch (dtype) {
1425 #if CUDA_VERSION >= 11000
1426     case dnn::kBF16: {
1427       CudaComputeCapability cc = stream->GetCudaComputeCapability();
1428       if (cc.IsAtLeast(7)) {
1429         cublasGemmAlgo_t algo =
1430             (cc.major >= 7 ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT);
1431         return DoBlasInternalImpl(
1432             AS_LAMBDA(cublasGemmStridedBatchedEx), stream,
1433             true /* = pointer_mode_host */, math_type,
1434             AsCublasOperation(transa), AsCublasOperation(transb), m, n, k,
1435             alpha, a.opaque(), CUDA_R_16BF, lda, stride_a, b.opaque(),
1436             CUDA_R_16BF, ldb, stride_b, beta, c->opaque(), CUDA_R_16BF, ldc,
1437             stride_c, batch_count,
1438             /*compute_type=*/CUDA_R_32F, algo);
1439       }
1440       // Fall back to a loop.
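      // cublasSgemmEx reads and writes bf16 operands here but accumulates in
      // fp32.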
1441       for (int batch = 0; batch < batch_count; ++batch) {
1442         const auto *a_matrix = reinterpret_cast<const __nv_bfloat16 *>(
1443             static_cast<const Eigen::bfloat16 *>(a.opaque()) +
1444             batch * stride_a);
1445         const auto *b_matrix = reinterpret_cast<const __nv_bfloat16 *>(
1446             static_cast<const Eigen::bfloat16 *>(b.opaque()) +
1447             batch * stride_b);
1448         auto *c_matrix = reinterpret_cast<__nv_bfloat16 *>(
1449             static_cast<Eigen::bfloat16 *>(c->opaque()) + batch * stride_c);
1450         TF_RETURN_IF_ERROR(DoBlasInternalImpl(
1451             cublasSgemmEx, stream, true /* = pointer_mode_host */,
1452             CUBLAS_DEFAULT_MATH, AsCublasOperation(transa),
1453             AsCublasOperation(transb), m, n, k,
1454             static_cast<const float *>(alpha), a_matrix, CUDA_R_16BF, lda,
1455             b_matrix, CUDA_R_16BF, ldb, static_cast<const float *>(beta),
1456             c_matrix, CUDA_R_16BF, ldc));
1457       }
1458       return ::tensorflow::OkStatus();
1459     }
1460 #endif
1461     case dnn::kHalf: {
1462 #if CUDA_VERSION >= 9010
1463       CudaComputeCapability cc = stream->GetCudaComputeCapability();
1464       if (cc.major >= 5) {
1465         cublasGemmAlgo_t algo =
1466             (cc.major >= 7 ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT);
1467         return DoBlasInternalImpl(
1468             AS_LAMBDA(cublasGemmStridedBatchedEx), stream,
1469             true /* = pointer_mode_host */, math_type,
1470             AsCublasOperation(transa), AsCublasOperation(transb), m, n, k,
1471             alpha, a.opaque(), CUDA_R_16F, lda, stride_a, b.opaque(),
1472             CUDA_R_16F, ldb, stride_b, beta, c->opaque(), CUDA_R_16F, ldc,
1473             stride_c, batch_count, CUDA_R_32F, algo);
1474       }
1475 #endif
1476       // Either CUDA_VERSION < 9.1 or SM < 5.0. Fall back to a loop.
1477       for (int batch = 0; batch < batch_count; ++batch) {
1478         const auto *a_matrix = reinterpret_cast<const __half *>(
1479             static_cast<const Eigen::half *>(a.opaque()) + batch * stride_a);
1480         const auto *b_matrix = reinterpret_cast<const __half *>(
1481             static_cast<const Eigen::half *>(b.opaque()) + batch * stride_b);
1482         auto *c_matrix = reinterpret_cast<__half *>(
1483             static_cast<Eigen::half *>(c->opaque()) + batch * stride_c);
1484         TF_RETURN_IF_ERROR(DoBlasInternalImpl(
1485             cublasSgemmEx, stream, true /* = pointer_mode_host */,
1486             CUBLAS_DEFAULT_MATH, AsCublasOperation(transa),
1487             AsCublasOperation(transb), m, n, k,
1488             static_cast<const float *>(alpha), a_matrix, SE_CUDA_DATA_HALF, lda,
1489             b_matrix, SE_CUDA_DATA_HALF, ldb, static_cast<const float *>(beta),
1490             c_matrix, SE_CUDA_DATA_HALF, ldc));
1491       }
1492       return ::tensorflow::OkStatus();
1493     }
1494     case dnn::kFloat: {
1495       return DoBlasInternalImpl(
1496           cublasSgemmStridedBatched, stream, true /* = pointer_mode_host */,
1497           math_type, AsCublasOperation(transa), AsCublasOperation(transb), m, n,
1498           k, static_cast<const float *>(alpha),
1499           static_cast<const float *>(a.opaque()), lda, stride_a,
1500           static_cast<const float *>(b.opaque()), ldb, stride_b,
1501           static_cast<const float *>(beta), static_cast<float *>(c->opaque()),
1502           ldc, stride_c, batch_count);
1503     }
1504     case dnn::kDouble:
1505       return DoBlasInternalImpl(
1506           cublasDgemmStridedBatched, stream, true /* = pointer_mode_host */,
1507           math_type, AsCublasOperation(transa), AsCublasOperation(transb), m, n,
1508           k, static_cast<const double *>(alpha),
1509           static_cast<const double *>(a.opaque()), lda, stride_a,
1510           static_cast<const double *>(b.opaque()), ldb, stride_b,
1511           static_cast<const double *>(beta), static_cast<double *>(c->opaque()),
1512           ldc, stride_c, batch_count);
1513     case dnn::kComplexFloat: {
1514       GpuComplexType cb_alpha =
1515           GpuComplexValue(*static_cast<const std::complex<float> *>(alpha));
1516       GpuComplexType cb_beta =
1517           GpuComplexValue(*static_cast<const std::complex<float> *>(beta));
1518       return DoBlasInternalImpl(
1519           cublasCgemmStridedBatched, stream, true /* = pointer_mode_host */,
1520           math_type, AsCublasOperation(transa), AsCublasOperation(transb), m, n,
1521           k, GpuComplex(&cb_alpha),
1522           static_cast<const GpuComplexType *>(a.opaque()), lda, stride_a,
1523           static_cast<const GpuComplexType *>(b.opaque()), ldb, stride_b,
1524           GpuComplex(&cb_beta), static_cast<GpuComplexType *>(c->opaque()), ldc,
1525           stride_c, batch_count);
1526     }
1527     case dnn::kComplexDouble: {
1528       GpuDoubleComplexType cb_alpha =
1529           GpuComplexValue(*static_cast<const std::complex<double> *>(alpha));
1530       GpuDoubleComplexType cb_beta =
1531           GpuComplexValue(*static_cast<const std::complex<double> *>(beta));
1532       return DoBlasInternalImpl(
1533           cublasZgemmStridedBatched, stream, true /* = pointer_mode_host */,
1534           math_type, AsCublasOperation(transa), AsCublasOperation(transb), m, n,
1535           k, GpuComplex(&cb_alpha),
1536           static_cast<const GpuDoubleComplexType *>(a.opaque()), lda, stride_a,
1537           static_cast<const GpuDoubleComplexType *>(b.opaque()), ldb, stride_b,
1538           GpuComplex(&cb_beta),
1539           static_cast<GpuDoubleComplexType *>(c->opaque()), ldc, stride_c,
1540           batch_count);
1541     }
1542     default:
1543       return port::InternalError(absl::StrCat("Unsupported datatype for GEMM: ",
1544                                               blas::DataTypeString(dtype)));
1545   }
1546 }
1547 
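// Triangular solves (TRSM): solve op(a) * x = alpha * b or x * op(a) = alpha * b
// for x, overwriting b with the solution. These map directly onto the
// corresponding cublas{S,D,C,Z}trsm calls.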
1548 bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,
1549                           blas::UpperLower uplo, blas::Transpose transa,
1550                           blas::Diagonal diag, uint64_t m, uint64 n,
1551                           float alpha, const DeviceMemory<float> &a, int lda,
1552                           DeviceMemory<float> *b, int ldb) {
1553   return DoBlasInternal(cublasStrsm, stream, true /* = pointer_mode_host */,
1554                         CUDABlasSide(side), CUDABlasUpperLower(uplo),
1555                         AsCublasOperation(transa), CUDABlasDiagonal(diag), m, n,
1556                         &alpha, GpuMemory(a), lda, GpuMemoryMutable(b), ldb);
1557 }
1558 
1559 bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,
1560                           blas::UpperLower uplo, blas::Transpose transa,
1561                           blas::Diagonal diag, uint64_t m, uint64 n,
1562                           double alpha, const DeviceMemory<double> &a, int lda,
1563                           DeviceMemory<double> *b, int ldb) {
1564   return DoBlasInternal(cublasDtrsm, stream, true /* = pointer_mode_host */,
1565                         CUDABlasSide(side), CUDABlasUpperLower(uplo),
1566                         AsCublasOperation(transa), CUDABlasDiagonal(diag), m, n,
1567                         &alpha, GpuMemory(a), lda, GpuMemoryMutable(b), ldb);
1568 }
1569 
1570 bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,
1571                           blas::UpperLower uplo, blas::Transpose transa,
1572                           blas::Diagonal diag, uint64_t m, uint64 n,
1573                           std::complex<float> alpha,
1574                           const DeviceMemory<std::complex<float>> &a, int lda,
1575                           DeviceMemory<std::complex<float>> *b, int ldb) {
1576   auto cb_alpha = GpuComplexValue(alpha);
1577   return DoBlasInternal(cublasCtrsm, stream, true /* = pointer_mode_host */,
1578                         CUDABlasSide(side), CUDABlasUpperLower(uplo),
1579                         AsCublasOperation(transa), CUDABlasDiagonal(diag), m, n,
1580                         GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
1581                         GpuComplex(GpuMemoryMutable(b)), ldb);
1582 }
1583 
1584 bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,
1585                           blas::UpperLower uplo, blas::Transpose transa,
1586                           blas::Diagonal diag, uint64_t m, uint64 n,
1587                           std::complex<double> alpha,
1588                           const DeviceMemory<std::complex<double>> &a, int lda,
1589                           DeviceMemory<std::complex<double>> *b, int ldb) {
1590   auto cb_alpha = GpuComplexValue(alpha);
1591   return DoBlasInternal(cublasZtrsm, stream, true /* = pointer_mode_host */,
1592                         CUDABlasSide(side), CUDABlasUpperLower(uplo),
1593                         AsCublasOperation(transa), CUDABlasDiagonal(diag), m, n,
1594                         GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
1595                         GpuComplex(GpuMemoryMutable(b)), ldb);
1596 }
1597 
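// Batched TRSM variants: `as` and `bs` are device arrays of per-batch matrix
// pointers, forwarded to cublas{S,D,C,Z}trsmBatched.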
1598 bool CUDABlas::DoBlasTrsmBatched(Stream *stream, blas::Side side,
1599                                  blas::UpperLower uplo, blas::Transpose transa,
1600                                  blas::Diagonal diag, uint64_t m, uint64 n,
1601                                  float alpha, const DeviceMemory<float *> &as,
1602                                  int lda, DeviceMemory<float *> *bs, int ldb,
1603                                  int batch_count) {
1604   return DoBlasInternal(cublasStrsmBatched, stream,
1605                         true /* = pointer_mode_host */, CUDABlasSide(side),
1606                         CUDABlasUpperLower(uplo), AsCublasOperation(transa),
1607                         CUDABlasDiagonal(diag), m, n, &alpha, GpuMemory(as),
1608                         lda, GpuMemoryMutable(bs), ldb, batch_count);
1609 }
1610 
1611 bool CUDABlas::DoBlasTrsmBatched(Stream *stream, blas::Side side,
1612                                  blas::UpperLower uplo, blas::Transpose transa,
1613                                  blas::Diagonal diag, uint64_t m, uint64 n,
1614                                  double alpha, const DeviceMemory<double *> &as,
1615                                  int lda, DeviceMemory<double *> *bs, int ldb,
1616                                  int batch_count) {
1617   return DoBlasInternal(cublasDtrsmBatched, stream,
1618                         true /* = pointer_mode_host */, CUDABlasSide(side),
1619                         CUDABlasUpperLower(uplo), AsCublasOperation(transa),
1620                         CUDABlasDiagonal(diag), m, n, &alpha, GpuMemory(as),
1621                         lda, GpuMemoryMutable(bs), ldb, batch_count);
1622 }
1623 
1624 bool CUDABlas::DoBlasTrsmBatched(Stream *stream, blas::Side side,
1625                                  blas::UpperLower uplo, blas::Transpose transa,
1626                                  blas::Diagonal diag, uint64_t m, uint64 n,
1627                                  std::complex<float> alpha,
1628                                  const DeviceMemory<std::complex<float> *> &as,
1629                                  int lda,
1630                                  DeviceMemory<std::complex<float> *> *bs,
1631                                  int ldb, int batch_count) {
1632   auto cb_alpha = GpuComplexValue(alpha);
1633   return DoBlasInternal(
1634       cublasCtrsmBatched, stream, true /* = pointer_mode_host */,
1635       CUDABlasSide(side), CUDABlasUpperLower(uplo), AsCublasOperation(transa),
1636       CUDABlasDiagonal(diag), m, n, &cb_alpha,
1637       reinterpret_cast<float2 *const *>(GpuMemory(as)), lda,
1638       reinterpret_cast<float2 **>(GpuMemoryMutable(bs)), ldb, batch_count);
1639 }
1640 
1641 bool CUDABlas::DoBlasTrsmBatched(Stream *stream, blas::Side side,
1642                                  blas::UpperLower uplo, blas::Transpose transa,
1643                                  blas::Diagonal diag, uint64_t m, uint64 n,
1644                                  std::complex<double> alpha,
1645                                  const DeviceMemory<std::complex<double> *> &as,
1646                                  int lda,
1647                                  DeviceMemory<std::complex<double> *> *bs,
1648                                  int ldb, int batch_count) {
1649   auto cb_alpha = GpuComplexValue(alpha);
1650   return DoBlasInternal(
1651       cublasZtrsmBatched, stream, true /* = pointer_mode_host */,
1652       CUDABlasSide(side), CUDABlasUpperLower(uplo), AsCublasOperation(transa),
1653       CUDABlasDiagonal(diag), m, n, &cb_alpha,
1654       reinterpret_cast<double2 *const *>(GpuMemory(as)), lda,
1655       reinterpret_cast<double2 **>(GpuMemoryMutable(bs)), ldb, batch_count);
1656 }
1657 
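// Reports the cuBLAS library version as the raw integer returned by
// cublasGetVersion, rendered as a decimal string.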
1658 port::Status CUDABlas::GetVersion(std::string *version) {
1659   absl::MutexLock lock(&mu_);
1660 
1661   int v;
1662   auto status = cublasGetVersion(blas_, &v);
1663   if (status != CUBLAS_STATUS_SUCCESS) {
1664     return port::InternalError(ToString(status));
1665   }
1666   *version = std::to_string(v);
1667   return ::tensorflow::OkStatus();
1668 }
1669 
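// Registers the cuBLAS plugin factory with the PluginRegistry and makes it the
// default BLAS implementation for the CUDA platform; this runs from the module
// initializer at the bottom of this file.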
1670 void initialize_cublas() {
1671   port::Status status =
1672       PluginRegistry::Instance()->RegisterFactory<PluginRegistry::BlasFactory>(
1673           kCudaPlatformId, kCuBlasPlugin, "cuBLAS",
1674           [](::stream_executor::internal::StreamExecutorInterface *parent)
1675               -> blas::BlasSupport * {
1676             gpu::GpuExecutor *cuda_executor =
1677                 dynamic_cast<gpu::GpuExecutor *>(parent);
1678             if (cuda_executor == nullptr) {
1679               LOG(ERROR)
1680                   << "Attempting to initialize an instance of the cuBLAS "
1681                   << "support library with a non-CUDA StreamExecutor";
1682               return nullptr;
1683             }
1684 
1685             CUDABlas *blas = new CUDABlas(cuda_executor);
1686             if (!blas->Init()) {
1687               // Note: Init() will log a more specific error.
1688               delete blas;
1689               return nullptr;
1690             }
1691             return blas;
1692           });
1693 
1694   if (!status.ok()) {
1695     LOG(ERROR) << "Unable to register cuBLAS factory: "
1696                << status.error_message();
1697   }
1698 
1699   PluginRegistry::Instance()->SetDefaultFactory(
1700       cuda::kCudaPlatformId, PluginKind::kBlas, kCuBlasPlugin);
1701 }
1702 
1703 }  // namespace cuda
1704 }  // namespace stream_executor
1705 
1706 REGISTER_MODULE_INITIALIZER(register_cublas,
1707                             { stream_executor::cuda::initialize_cublas(); });
1708