1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "third_party/gpus/cuda/include/cublas_v2.h"
17 #include "third_party/gpus/cuda/include/cuda.h"
18
19 #define SE_CUDA_DATA_HALF CUDA_R_16F
20
21 #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.h"
22
23 // Both Eigen's Half.h and CUDA's cuda_fp16.h provide a similar typedef for
24 // __half. As such, there are two ways to get the typedef for __half:
25 //
26 // (1) Includes cuda_fp16.h and defines EIGEN_HAS_CUDA_FP16.
27 // (2) Neither includes cuda_fp16.h nor defines EIGEN_HAS_CUDA_FP16.
28 //
29 // Due to issue b/73793421, when the first approach is used and NVCC is used to
30 // compile this file, NVCC will complain about a duplicated definition of
31 // EIGEN_HAS_CUDA_FP16. On the other hand, when the second approach is used and
32 // clang is used to compile this file, clang will not understand __half
33 // because the definition and the EIGEN_HAS_CUDA_FP16 macro are missing.
34 //
35 // Because this file may be compiled with clang but will never be compiled with
36 // NVCC, we choose the first approach for CUDA < 9.0. For CUDA >= 9.0, we have
37 // to use the second approach, because the data member of the __half defined
38 // by CUDA >= 9.0 is `__x` while Eigen expects it to be `x`.
39 //
40 // TODO(b/73793421): Remove the following code block to switch to the second
41 // approach when the issue is fixed.
42 #if CUDA_VERSION < 9000
43 #include "third_party/gpus/cuda/include/cuda_fp16.h"
44 #define EIGEN_HAS_CUDA_FP16
45 #endif
46
47 #include <complex>
48
49 #include "absl/strings/str_cat.h"
50 #include "absl/strings/str_format.h"
51 #include "third_party/eigen3/Eigen/Core"
52 #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_activation.h"
53 #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_blas_utils.h"
54 #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.h"
55 #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_helpers.h"
56 #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_platform_id.h"
57 #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_stream.h"
58 #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_timer.h"
59 #include "tensorflow/compiler/xla/stream_executor/device_memory.h"
60 #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_executor.h"
61 #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_helpers.h"
62 #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_stream.h"
63 #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_timer.h"
64 #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_types.h"
65 #include "tensorflow/compiler/xla/stream_executor/lib/initialize.h"
66 #include "tensorflow/compiler/xla/stream_executor/lib/status.h"
67 #include "tensorflow/compiler/xla/stream_executor/platform/logging.h"
68 #include "tensorflow/compiler/xla/stream_executor/platform/port.h"
69 #include "tensorflow/compiler/xla/stream_executor/plugin_registry.h"
70 #include "tensorflow/compiler/xla/stream_executor/scratch_allocator.h"
71 #include "tensorflow/compiler/xla/stream_executor/stream_executor.h"
72 #include "tensorflow/core/platform/tensor_float_32_utils.h"
73
74 namespace stream_executor {
75 namespace cuda {
76
77 using gpu::AsGpuStream;
78 using gpu::AsGpuStreamValue;
79 using gpu::GpuComplex;
80 using gpu::GpuComplexT;
81 using gpu::GpuComplexType;
82 using gpu::GpuComplexValue;
83 using gpu::GpuDoubleComplexType;
84 using gpu::GpuExecutor;
85 using gpu::GpuMemory;
86 using gpu::GpuMemoryMutable;
87 using gpu::GpuTimer;
88 using gpu::GpuTimerDeleter;
89
90 PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kCuBlasPlugin);
91
92 // cuBLAS has interfaces that permit pointers to be passed from either the host
93 // memory space or the device memory space; however, you must instruct it as to
94 // which address space those pointers are in with cublasSetPointerMode.
95 //
96 // This helper sets the cuBLAS pointer mode to a desired value for a cuBLAS call
97 // you are about to perform in a given scope.
98 //
99 // The prior cuBLAS pointer mode is retained and restored when this object goes
100 // out of scope.
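//
// A minimal usage sketch (hypothetical call site, for illustration only;
// `handle` is assumed to be a valid cublasHandle_t):
//
//   ScopedCublasPointerMode pointer_mode{handle};
//   if (!pointer_mode.Init(CUBLAS_POINTER_MODE_HOST)) return false;
//   // ... issue cuBLAS calls that read alpha/beta from host memory ...
//   // The prior pointer mode is restored when pointer_mode goes out of scope.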
101 class ScopedCublasPointerMode {
102 public:
103 // Note that, because the setting of the cublas pointer mode is fallible,
104 // construction of this scoped datatype must be paired with a call to
105 // Init().
106 //
107 // Parameters:
108 // handle: The cublas library handle to act upon in setting the pointer mode.
109 explicit ScopedCublasPointerMode(cublasHandle_t handle)
110 : handle_(handle), ok_(false) {}
111
112 // Attempts the switch to the requested scoped pointer mode, new_mode.
113 //
114 // Note that when false is returned, an appropriate error has already been
115 // logged.
116 bool Init(cublasPointerMode_t new_mode) {
117 cublasStatus_t ret = cublasGetPointerMode(handle_, &old_mode_);
118 if (ret != CUBLAS_STATUS_SUCCESS) {
119 LOG(ERROR) << "failed to get old cublas pointer mode: " << ToString(ret);
120 return ok_ = false;
121 }
122
123 ret = cublasSetPointerMode(handle_, new_mode);
124 if (ret != CUBLAS_STATUS_SUCCESS) {
125 LOG(ERROR) << "failed to set new cublas pointer mode: " << ToString(ret);
126 return ok_ = false;
127 }
128
129 return ok_ = true;
130 }
131
132 // Switches back to the prior pointer mode, if the switch operation was
133 // successful in the first place.
134 ~ScopedCublasPointerMode() {
135 if (ok_) {
136 cublasStatus_t ret = cublasSetPointerMode(handle_, old_mode_);
137 if (ret != CUBLAS_STATUS_SUCCESS) {
138 LOG(ERROR) << "failed to set former cublas pointer mode: "
139 << ToString(ret);
140 }
141 }
142 }
143
144 private:
145 cublasHandle_t handle_; // Handle to the cuBLAS instance of interest.
146 cublasPointerMode_t old_mode_; // Prior cuBLAS pointer mode, to be restored.
147 bool ok_; // Whether the change was successful.
148 };
149
150 #if CUDA_VERSION >= 9000
151 // cuBLAS has interfaces that permit computations to use tensor ops on Volta
152 // and newer hardware. This must be enabled via the cublasGet/SetMathMode APIs.
153 //
154 // This helper sets the cuBLAS math mode to a desired value for a cuBLAS call
155 // you are about to perform in a given scope.
156 //
157 // The prior cuBLAS math mode is retained and restored when this object goes
158 // out of scope.
159 class ScopedCublasMathMode {
160 public:
161 // Note that, because the setting of the cublas math mode is fallible,
162 // construction of this scoped datatype must be paired with a call to
163 // Init().
164 //
165 // Parameters:
166 // handle: The cublas library handle to act upon in setting the math mode.
167 explicit ScopedCublasMathMode(cublasHandle_t handle)
168 : handle_(handle), ok_(false) {}
169
170 // Attempts the switch to the requested scoped math mode, new_mode.
171 //
172 // Note that when false is returned, an appropriate error has already been
173 // logged.
174 bool Init(cublasMath_t new_mode) {
175 cublasStatus_t ret = cublasGetMathMode(handle_, &old_mode_);
176 if (ret != CUBLAS_STATUS_SUCCESS) {
177 LOG(ERROR) << "failed to get old cublas math mode: " << ToString(ret);
178 return ok_ = false;
179 }
180
181 ret = cublasSetMathMode(handle_, new_mode);
182 if (ret != CUBLAS_STATUS_SUCCESS) {
183 LOG(ERROR) << "failed to set new cublas math mode: " << ToString(ret);
184 return ok_ = false;
185 }
186 return ok_ = true;
187 }
188
189 // Switches back to the prior math mode, if the switch operation was
190 // successful in the first place.
191 ~ScopedCublasMathMode() {
192 if (ok_) {
193 cublasStatus_t ret = cublasSetMathMode(handle_, old_mode_);
194 if (ret != CUBLAS_STATUS_SUCCESS) {
195 LOG(ERROR) << "failed to set former cublas math mode: "
196 << ToString(ret);
197 }
198 }
199 }
200
201 private:
202 cublasHandle_t handle_; // Handle to the cuBLAS instance of interest.
203 cublasMath_t old_mode_; // Prior cuBLAS math mode, to be restored.
204 bool ok_; // Whether the change was successful.
205 };
206 #endif // CUDA_VERSION >= 9000
207
208 static const char *const kCublasNotInitializedExplanation =
209 "Failure to initialize cublas may be due to OOM (cublas needs some free "
210 "memory when you initialize it, and your deep-learning framework may have "
211 "preallocated more than its fair share), or may be because this binary was "
212 "not built with support for the GPU in your machine.";
213
214 bool CUDABlas::Init() {
215 gpu::ScopedActivateExecutorContext sac{parent_};
216 cublasStatus_t ret = cublasCreate(&blas_);
217 if (ret != CUBLAS_STATUS_SUCCESS) {
218 LOG(ERROR) << "failed to create cublas handle: " << ToString(ret);
219 if (ret == CUBLAS_STATUS_NOT_INITIALIZED) {
220 LOG(ERROR) << kCublasNotInitializedExplanation;
221 }
222 return false;
223 }
224
225 #if CUDA_VERSION >= 11000
226 if (!blas_lt_.Init().ok()) {
227 LOG(ERROR) << kCublasNotInitializedExplanation;
228 return false;
229 }
230 #endif // CUDA_VERSION >= 11000
231
232 return true;
233 }
234
235 CUDABlas::CUDABlas(gpu::GpuExecutor *parent)
236 : parent_(CHECK_NOTNULL(parent)),
237 blas_(nullptr)
238 #if CUDA_VERSION >= 11000
239 ,
240 blas_lt_(parent)
241 #endif
242 {
243 }
244
245 CUDABlas::~CUDABlas() {
246 if (blas_ != nullptr) {
247 gpu::ScopedActivateExecutorContext sac{parent_};
248 cublasDestroy(blas_);
249 }
250 }
251
252 bool CUDABlas::SetStream(Stream *stream) {
253 CHECK(stream != nullptr);
254 CHECK(AsGpuStreamValue(stream) != nullptr);
255 CHECK(blas_ != nullptr);
256 gpu::ScopedActivateExecutorContext sac{parent_};
257 cublasStatus_t ret = cublasSetStream(blas_, AsGpuStreamValue(stream));
258 if (ret != CUBLAS_STATUS_SUCCESS) {
259 LOG(ERROR) << "failed to set stream for cuBLAS calls: " << ToString(ret);
260 return false;
261 }
262
263 return true;
264 }
265
266 cudaStream_t CUDABlas::CUDAStream(Stream *stream) {
267 CHECK(stream != nullptr);
268 CHECK(AsGpuStreamValue(stream) != nullptr);
269 gpu::ScopedActivateExecutorContext sac{parent_};
270 return AsGpuStreamValue(stream);
271 }
272
273 namespace {
274
275 // Helper functions transforming blas arguments into cuBLAS arguments.
276
277 cublasFillMode_t CUDABlasUpperLower(blas::UpperLower uplo) {
278 switch (uplo) {
279 case blas::UpperLower::kUpper:
280 return CUBLAS_FILL_MODE_UPPER;
281 case blas::UpperLower::kLower:
282 return CUBLAS_FILL_MODE_LOWER;
283 default:
284 LOG(FATAL) << "Invalid value of blas::UpperLower.";
285 }
286 }
287
288 cublasDiagType_t CUDABlasDiagonal(blas::Diagonal diag) {
289 switch (diag) {
290 case blas::Diagonal::kUnit:
291 return CUBLAS_DIAG_UNIT;
292 case blas::Diagonal::kNonUnit:
293 return CUBLAS_DIAG_NON_UNIT;
294 default:
295 LOG(FATAL) << "Invalid value of blas::Diagonal.";
296 }
297 }
298
299 cublasSideMode_t CUDABlasSide(blas::Side side) {
300 switch (side) {
301 case blas::Side::kLeft:
302 return CUBLAS_SIDE_LEFT;
303 case blas::Side::kRight:
304 return CUBLAS_SIDE_RIGHT;
305 default:
306 LOG(FATAL) << "Invalid value of blas::Side.";
307 }
308 }
309
310 // CUDADataType<T>::type translates from a C++ type (e.g. float) to a
311 // cudaDataType_t (e.g. CUDA_R_32F).
312 //
313 // These are used to build the argument type and computation type args to
314 // cublasGemmEx.
315 template <typename T>
316 struct CUDADataType;
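// For example, CUDADataType<Eigen::half>::type evaluates to CUDA_R_16F (via
// the SE_CUDA_DATA_HALF alias defined above) and CUDADataType<float>::type
// evaluates to CUDA_R_32F, as the specializations below spell out.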
317
318 template <>
319 struct CUDADataType<Eigen::half> {
320 static constexpr cudaDataType_t type = SE_CUDA_DATA_HALF;
321 };
322
323 template <>
324 struct CUDADataType<std::complex<Eigen::half>> {
325 static constexpr cudaDataType_t type = CUDA_C_16F;
326 };
327
328 template <>
329 struct CUDADataType<float> {
330 static constexpr cudaDataType_t type = CUDA_R_32F;
331 };
332
333 template <>
334 struct CUDADataType<std::complex<float>> {
335 static constexpr cudaDataType_t type = CUDA_C_32F;
336 };
337
338 template <>
339 struct CUDADataType<double> {
340 static constexpr cudaDataType_t type = CUDA_R_64F;
341 };
342
343 template <>
344 struct CUDADataType<std::complex<double>> {
345 static constexpr cudaDataType_t type = CUDA_C_64F;
346 };
347
348 template <>
349 struct CUDADataType<int> {
350 static constexpr cudaDataType_t type = CUDA_R_32I;
351 };
352
353 template <>
354 struct CUDADataType<int8> {
355 static constexpr cudaDataType_t type = CUDA_R_8I;
356 };
357
358 template <>
359 struct CUDADataType<std::complex<int8>> {
360 static constexpr cudaDataType_t type = CUDA_C_8I;
361 };
362
363 template <>
364 struct CUDADataType<uint8> {
365 static constexpr cudaDataType_t type = CUDA_R_8U;
366 };
367
368 template <>
369 struct CUDADataType<std::complex<uint8>> {
370 static constexpr cudaDataType_t type = CUDA_C_8U;
371 };
372
373 } // namespace
374
375 template <typename FuncT, typename... Args>
376 port::Status CUDABlas::DoBlasInternalImpl(FuncT cublas_func, Stream *stream,
377 bool pointer_mode_host,
378 cublasMath_t math_type,
379 Args... args) {
380 absl::MutexLock lock(&mu_);
381
382 CHECK(blas_ != nullptr);
383 if (!SetStream(stream)) {
384 return port::InternalError("Failed setting stream");
385 }
386
387 #if CUDA_VERSION >= 9000
388 ScopedCublasMathMode math_mode{blas_};
389 #if CUBLAS_VER_MAJOR >= 11
390 if (math_type == CUBLAS_TF32_TENSOR_OP_MATH &&
391 tensorflow::tensor_float_32_execution_enabled()) {
392 #else
393 if (math_type == CUBLAS_TENSOR_OP_MATH) {
394 #endif
395 if (!math_mode.Init(math_type)) {
396 return port::InternalError("Failed initializing math mode");
397 }
398 }
399 #endif
400
401 gpu::ScopedActivateExecutorContext sac{parent_};
402 ScopedCublasPointerMode pointer_mode{blas_};
403 if (!pointer_mode.Init(pointer_mode_host ? CUBLAS_POINTER_MODE_HOST
404 : CUBLAS_POINTER_MODE_DEVICE)) {
405 return port::InternalError("Failed setting error mode");
406 }
407 cublasStatus_t ret = cublas_func(blas_, args...);
408 if (ret == CUBLAS_STATUS_SUCCESS) {
409 return ::tensorflow::OkStatus();
410 }
411 return port::InternalError(ToString(ret));
412 }
413
414 // cublas_func may be overloaded, so we need to figure out which one we really
415 // need to call based on the args. One way to do it is to wrap it in a lambda.
416 #define AS_LAMBDA(func) \
417 [](auto &&...args) -> decltype(func( \
418 std::forward<decltype(args)>(args)...)) { \
419 return func(std::forward<decltype(args)>(args)...); \
420 }
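// For example, cublasGemmEx is passed as AS_LAMBDA(cublasGemmEx) to
// DoBlasInternalImpl later in this file, so that overload resolution happens
// at the point of the call rather than when taking the function's address.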
421
422 bool CUDABlas::DoBlasAxpy(Stream *stream, uint64_t elem_count, float alpha,
423 const DeviceMemory<float> &x, int incx,
424 DeviceMemory<float> *y, int incy) {
425 return DoBlasInternal(cublasSaxpy, stream, true /* = pointer_mode_host */,
426 elem_count, &alpha, GpuMemory(x), incx,
427 GpuMemoryMutable(y), incy);
428 }
429
430 bool CUDABlas::DoBlasAxpy(Stream *stream, uint64_t elem_count, double alpha,
431 const DeviceMemory<double> &x, int incx,
432 DeviceMemory<double> *y, int incy) {
433 return DoBlasInternal(cublasDaxpy, stream, true /* = pointer_mode_host */,
434 elem_count, &alpha, GpuMemory(x), incx,
435 GpuMemoryMutable(y), incy);
436 }
437
438 bool CUDABlas::DoBlasAxpy(Stream *stream, uint64_t elem_count,
439 std::complex<float> alpha,
440 const DeviceMemory<std::complex<float>> &x, int incx,
441 DeviceMemory<std::complex<float>> *y, int incy) {
442 auto cb_alpha = GpuComplexValue(alpha);
443 return DoBlasInternal(cublasCaxpy, stream, true /* = pointer_mode_host */,
444 elem_count, GpuComplex(&cb_alpha),
445 GpuComplex(GpuMemory(x)), incx,
446 GpuComplex(GpuMemoryMutable(y)), incy);
447 }
448
449 bool CUDABlas::DoBlasAxpy(Stream *stream, uint64_t elem_count,
450 std::complex<double> alpha,
451 const DeviceMemory<std::complex<double>> &x, int incx,
452 DeviceMemory<std::complex<double>> *y, int incy) {
453 auto cb_alpha = GpuComplexValue(alpha);
454 return DoBlasInternal(cublasZaxpy, stream, true /* = pointer_mode_host */,
455 elem_count, GpuComplex(&cb_alpha),
456 GpuComplex(GpuMemory(x)), incx,
457 GpuComplex(GpuMemoryMutable(y)), incy);
458 }
459
460 bool CUDABlas::DoBlasCopy(Stream *stream, uint64_t elem_count,
461 const DeviceMemory<float> &x, int incx,
462 DeviceMemory<float> *y, int incy) {
463 return DoBlasInternal(cublasScopy, stream, true /* = pointer_mode_host */,
464 elem_count, GpuMemory(x), incx, GpuMemoryMutable(y),
465 incy);
466 }
467
468 bool CUDABlas::DoBlasCopy(Stream *stream, uint64_t elem_count,
469 const DeviceMemory<double> &x, int incx,
470 DeviceMemory<double> *y, int incy) {
471 return DoBlasInternal(cublasDcopy, stream, true /* = pointer_mode_host */,
472 elem_count, GpuMemory(x), incx, GpuMemoryMutable(y),
473 incy);
474 }
475
476 bool CUDABlas::DoBlasCopy(Stream *stream, uint64_t elem_count,
477 const DeviceMemory<std::complex<float>> &x, int incx,
478 DeviceMemory<std::complex<float>> *y, int incy) {
479 return DoBlasInternal(cublasCcopy, stream, true /* = pointer_mode_host */,
480 elem_count, GpuComplex(GpuMemory(x)), incx,
481 GpuComplex(GpuMemoryMutable(y)), incy);
482 }
483
484 bool CUDABlas::DoBlasCopy(Stream *stream, uint64_t elem_count,
485 const DeviceMemory<std::complex<double>> &x, int incx,
486 DeviceMemory<std::complex<double>> *y, int incy) {
487 return DoBlasInternal(cublasZcopy, stream, true /* = pointer_mode_host */,
488 elem_count, GpuComplex(GpuMemory(x)), incx,
489 GpuComplex(GpuMemoryMutable(y)), incy);
490 }
491
492 bool CUDABlas::DoBlasScal(Stream *stream, uint64_t elem_count, float alpha,
493 DeviceMemory<float> *x, int incx) {
494 return DoBlasInternal(cublasSscal, stream, true /* = pointer_mode_host */,
495 elem_count, &alpha, GpuMemoryMutable(x), incx);
496 }
497
498 bool CUDABlas::DoBlasScal(Stream *stream, uint64_t elem_count, double alpha,
499 DeviceMemory<double> *x, int incx) {
500 return DoBlasInternal(cublasDscal, stream, true /* = pointer_mode_host */,
501 elem_count, &alpha, GpuMemoryMutable(x), incx);
502 }
503
504 bool CUDABlas::DoBlasScal(Stream *stream, uint64_t elem_count, float alpha,
505 DeviceMemory<std::complex<float>> *x, int incx) {
506 return DoBlasInternal(cublasCsscal, stream, true /* = pointer_mode_host */,
507 elem_count, &alpha, GpuComplex(GpuMemoryMutable(x)),
508 incx);
509 }
510
511 bool CUDABlas::DoBlasScal(Stream *stream, uint64_t elem_count, double alpha,
512 DeviceMemory<std::complex<double>> *x, int incx) {
513 return DoBlasInternal(cublasZdscal, stream, true /* = pointer_mode_host */,
514 elem_count, &alpha, GpuComplex(GpuMemoryMutable(x)),
515 incx);
516 }
517
518 bool CUDABlas::DoBlasScal(Stream *stream, uint64_t elem_count,
519 std::complex<float> alpha,
520 DeviceMemory<std::complex<float>> *x, int incx) {
521 auto cb_alpha = GpuComplexValue(alpha);
522 return DoBlasInternal(cublasCscal, stream, true /* = pointer_mode_host */,
523 elem_count, GpuComplex(&cb_alpha),
524 GpuComplex(GpuMemoryMutable(x)), incx);
525 }
526
527 bool CUDABlas::DoBlasScal(Stream *stream, uint64_t elem_count,
528 std::complex<double> alpha,
529 DeviceMemory<std::complex<double>> *x, int incx) {
530 auto cb_alpha = GpuComplexValue(alpha);
531 return DoBlasInternal(cublasZscal, stream, true /* = pointer_mode_host */,
532 elem_count, GpuComplex(&cb_alpha),
533 GpuComplex(GpuMemoryMutable(x)), incx);
534 }
535
536 bool CUDABlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m,
537 uint64_t n, float alpha, const DeviceMemory<float> &a,
538 int lda, const DeviceMemory<float> &x, int incx,
539 float beta, DeviceMemory<float> *y, int incy) {
540 return DoBlasInternal(cublasSgemv, stream, true /* = pointer_mode_host */,
541 AsCublasOperation(trans), m, n, &alpha, GpuMemory(a),
542 lda, GpuMemory(x), incx, &beta, GpuMemoryMutable(y),
543 incy);
544 }
545
546 bool CUDABlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m,
547 uint64_t n, double alpha,
548 const DeviceMemory<double> &a, int lda,
549 const DeviceMemory<double> &x, int incx, double beta,
550 DeviceMemory<double> *y, int incy) {
551 return DoBlasInternal(cublasDgemv, stream, true /* = pointer_mode_host */,
552 AsCublasOperation(trans), m, n, &alpha, GpuMemory(a),
553 lda, GpuMemory(x), incx, &beta, GpuMemoryMutable(y),
554 incy);
555 }
556
557 bool CUDABlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m,
558 uint64_t n, std::complex<float> alpha,
559 const DeviceMemory<std::complex<float>> &a, int lda,
560 const DeviceMemory<std::complex<float>> &x, int incx,
561 std::complex<float> beta,
562 DeviceMemory<std::complex<float>> *y, int incy) {
563 auto cb_alpha = GpuComplexValue(alpha);
564 auto cb_beta = GpuComplexValue(beta);
565 return DoBlasInternal(cublasCgemv, stream, true /* = pointer_mode_host */,
566 AsCublasOperation(trans), m, n, GpuComplex(&cb_alpha),
567 GpuComplex(GpuMemory(a)), lda, GpuComplex(GpuMemory(x)),
568 incx, GpuComplex(&cb_beta),
569 GpuComplex(GpuMemoryMutable(y)), incy);
570 }
571
572 bool CUDABlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m,
573 uint64_t n, std::complex<double> alpha,
574 const DeviceMemory<std::complex<double>> &a, int lda,
575 const DeviceMemory<std::complex<double>> &x, int incx,
576 std::complex<double> beta,
577 DeviceMemory<std::complex<double>> *y, int incy) {
578 auto cb_alpha = GpuComplexValue(alpha);
579 auto cb_beta = GpuComplexValue(beta);
580 return DoBlasInternal(cublasZgemv, stream, true /* = pointer_mode_host */,
581 AsCublasOperation(trans), m, n, GpuComplex(&cb_alpha),
582 GpuComplex(GpuMemory(a)), lda, GpuComplex(GpuMemory(x)),
583 incx, GpuComplex(&cb_beta),
584 GpuComplex(GpuMemoryMutable(y)), incy);
585 }
586
587 bool CUDABlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64_t n,
588 uint64_t k, float alpha, const DeviceMemory<float> &a,
589 int lda, const DeviceMemory<float> &x, int incx,
590 float beta, DeviceMemory<float> *y, int incy) {
591 return DoBlasInternal(cublasSsbmv, stream, true /* = pointer_mode_host */,
592 CUDABlasUpperLower(uplo), n, k, &alpha, GpuMemory(a),
593 lda, GpuMemory(x), incx, &beta, GpuMemoryMutable(y),
594 incy);
595 }
596
597 bool CUDABlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64_t n,
598 uint64_t k, double alpha,
599 const DeviceMemory<double> &a, int lda,
600 const DeviceMemory<double> &x, int incx, double beta,
601 DeviceMemory<double> *y, int incy) {
602 return DoBlasInternal(cublasDsbmv, stream, true /* = pointer_mode_host */,
603 CUDABlasUpperLower(uplo), n, k, &alpha, GpuMemory(a),
604 lda, GpuMemory(x), incx, &beta, GpuMemoryMutable(y),
605 incy);
606 }
607
608 port::Status CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
609 blas::Transpose transb, uint64_t m, uint64 n,
610 uint64_t k, blas::DataType dtype,
611 const void *alpha, const DeviceMemoryBase &a,
612 int lda, const DeviceMemoryBase &b, int ldb,
613 const void *beta, DeviceMemoryBase *c,
614 int ldc, blas::ComputePrecision precision) {
615 cublasMath_t math_type = CUBLAS_DEFAULT_MATH;
616
617 #if CUDA_VERSION < 11000
618 if (dtype == blas::DataType::kHalf) {
619 math_type = CUBLAS_TENSOR_OP_MATH;
620 }
621 #else
622 if (dtype == blas::DataType::kFloat) {
623 math_type = CUBLAS_TF32_TENSOR_OP_MATH;
624 if (stream->GetCudaComputeCapability().IsAtLeast(
625 CudaComputeCapability::AMPERE)) {
626 // TODO(reedwm): Remove or make this VLOG(1) once TensorFloat-32 is more
627 // well tested.
628 if (tensorflow::tensor_float_32_execution_enabled()) {
629 LOG_FIRST_N(INFO, 1) << "TensorFloat-32 will be used for the matrix "
630 "multiplication. This will only be logged "
631 "once.";
632 }
633 }
634 if (precision > blas::kDefaultComputePrecision) {
635 math_type = CUBLAS_DEFAULT_MATH;
636 }
637 }
638 #endif
639
640 // TODO(cheshire): Return an error instead.
641 // TODO(cheshire): Why are these checked only for `half` and `float`?
642 if (dtype == blas::DataType::kHalf || dtype == blas::DataType::kFloat) {
643 if (transa == blas::Transpose::kNoTranspose) {
644 if (lda < static_cast<int64_t>(m)) {
645 LOG(WARNING) << "GEMM lda was smaller than m (no transpose case); "
646 "precondition violation";
647 }
648 } else {
649 if (lda < static_cast<int64_t>(k)) {
650 LOG(WARNING) << "GEMM lda (" << lda << ") was smaller than k (" << k
651 << ") (transpose case); precondition violation";
652 }
653 }
654 if (transb == blas::Transpose::kNoTranspose) {
655 if (ldb < static_cast<int64_t>(k)) {
656 LOG(WARNING) << "GEMM ldb (" << ldb << ") was smaller than k (" << k
657 << ") (no transpose case); precondition violation";
658 }
659 } else {
660 if (ldb < static_cast<int64_t>(n)) {
661 LOG(WARNING) << "GEMM ldb was smaller than n (transpose case); "
662 "precondition violation";
663 }
664 }
665 }
666
667 VLOG(1) << absl::StrFormat(
668 "doing cuBLAS SGEMM: at=%d bt=%d m=%u n=%u "
669 "k=%u alpha=%p a=%p lda=%d b=%p ldb=%d beta=%p "
670 "c=%p ldc=%d",
671 static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha,
672 a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc);
673
674 switch (dtype) {
675 case blas::DataType::kHalf: {
676 #if CUDA_VERSION < 7050
677 return port::InternalError(
678 "fp16 sgemm is not implemented in this cuBLAS version "
679 "(need at least CUDA 7.5)");
680 #endif
681
682 return DoBlasInternalImpl(
683 cublasSgemmEx, stream, true /* = pointer_mode_host */, math_type,
684 AsCublasOperation(transa), AsCublasOperation(transb), m, n, k,
685 static_cast<const float *>(alpha), a.opaque(), SE_CUDA_DATA_HALF, lda,
686 b.opaque(), SE_CUDA_DATA_HALF, ldb, static_cast<const float *>(beta),
687 c->opaque(), SE_CUDA_DATA_HALF, ldc);
688 }
689 #if CUDA_VERSION > 11000
690 case blas::DataType::kBF16: {
691 return DoBlasInternalImpl(
692 cublasSgemmEx, stream, true /* = pointer_mode_host */, math_type,
693 AsCublasOperation(transa), AsCublasOperation(transb), m, n, k,
694 static_cast<const float *>(alpha), a.opaque(), CUDA_R_16BF, lda,
695 b.opaque(), CUDA_R_16BF, ldb, static_cast<const float *>(beta),
696 c->opaque(), CUDA_R_16BF, ldc);
697 }
698 #endif
699 case dnn::kFloat:
700 return DoBlasInternalImpl(
701 cublasSgemm, stream, true /* = pointer_mode_host */, math_type,
702 AsCublasOperation(transa), AsCublasOperation(transb), m, n, k,
703 static_cast<const float *>(alpha),
704 static_cast<const float *>(a.opaque()), lda,
705 static_cast<const float *>(b.opaque()), ldb,
706 static_cast<const float *>(beta), static_cast<float *>(c->opaque()),
707 ldc);
708 case dnn::kDouble:
709 return DoBlasInternalImpl(
710 cublasDgemm, stream, true /* = pointer_mode_host */, math_type,
711 AsCublasOperation(transa), AsCublasOperation(transb), m, n, k,
712 static_cast<const double *>(alpha),
713 static_cast<const double *>(a.opaque()), lda,
714 static_cast<const double *>(b.opaque()), ldb,
715 static_cast<const double *>(beta), static_cast<double *>(c->opaque()),
716 ldc);
717 case dnn::kComplexFloat: {
718 GpuComplexType cb_alpha =
719 GpuComplexValue(*static_cast<const std::complex<float> *>(alpha));
720 GpuComplexType cb_beta =
721 GpuComplexValue(*static_cast<const std::complex<float> *>(beta));
722 return DoBlasInternalImpl(
723 cublasCgemm, stream, true /* = pointer_mode_host */, math_type,
724 AsCublasOperation(transa), AsCublasOperation(transb), m, n, k,
725 &cb_alpha, static_cast<const GpuComplexType *>(a.opaque()), lda,
726 static_cast<const GpuComplexType *>(b.opaque()), ldb, &cb_beta,
727 static_cast<GpuComplexType *>(c->opaque()), ldc);
728 }
729 case dnn::kComplexDouble: {
730 GpuDoubleComplexType cb_alpha =
731 GpuComplexValue(*static_cast<const std::complex<double> *>(alpha));
732 GpuDoubleComplexType cb_beta =
733 GpuComplexValue(*static_cast<const std::complex<double> *>(beta));
734 return DoBlasInternalImpl(
735 cublasZgemm, stream, true /* = pointer_mode_host */, math_type,
736 AsCublasOperation(transa), AsCublasOperation(transb), m, n, k,
737 &cb_alpha, static_cast<const GpuDoubleComplexType *>(a.opaque()), lda,
738 static_cast<const GpuDoubleComplexType *>(b.opaque()), ldb, &cb_beta,
739 static_cast<GpuDoubleComplexType *>(c->opaque()), ldc);
740 }
741 default:
742 return port::InternalError(absl::StrCat("Unsupported datatype for GEMM: ",
743 blas::DataTypeString(dtype)));
744 }
745 }
746
747 bool CUDABlas::DoBlasGemvWithProfiling(
748 Stream *stream, blas::Transpose trans, uint64_t m, uint64 n, float alpha,
749 const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,
750 int incx, float beta, DeviceMemory<float> *y, int incy,
751 blas::ProfileResult *output_profile_result) {
752 return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
753 incx, beta, y, incy,
754 output_profile_result);
755 }
756
757 bool CUDABlas::DoBlasGemvWithProfiling(
758 Stream *stream, blas::Transpose trans, uint64_t m, uint64 n, double alpha,
759 const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,
760 int incx, double beta, DeviceMemory<double> *y, int incy,
761 blas::ProfileResult *output_profile_result) {
762 return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
763 incx, beta, y, incy,
764 output_profile_result);
765 }
766
767 bool CUDABlas::DoBlasGemvWithProfiling(
768 Stream *stream, blas::Transpose trans, uint64_t m, uint64 n,
769 std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a,
770 int lda, const DeviceMemory<std::complex<float>> &x, int incx,
771 std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
772 blas::ProfileResult *output_profile_result) {
773 return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
774 incx, beta, y, incy,
775 output_profile_result);
776 }
777
778 bool CUDABlas::DoBlasGemvWithProfiling(
779 Stream *stream, blas::Transpose trans, uint64_t m, uint64 n,
780 std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a,
781 int lda, const DeviceMemory<std::complex<double>> &x, int incx,
782 std::complex<double> beta, DeviceMemory<std::complex<double>> *y, int incy,
783 blas::ProfileResult *output_profile_result) {
784 return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
785 incx, beta, y, incy,
786 output_profile_result);
787 }
788
789 bool CUDABlas::DoBlasGemmWithProfiling(
790 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
791 uint64_t n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
792 int lda, const DeviceMemory<Eigen::half> &b, int ldb, float beta,
793 DeviceMemory<Eigen::half> *c, int ldc,
794 blas::ProfileResult *output_profile_result) {
795 return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
796 lda, b, ldb, beta, c, ldc,
797 output_profile_result);
798 }
799
800 bool CUDABlas::DoBlasGemmWithProfiling(
801 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
802 uint64_t n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
803 const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
804 int ldc, blas::ProfileResult *output_profile_result) {
805 return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
806 lda, b, ldb, beta, c, ldc,
807 output_profile_result);
808 }
809
810 bool CUDABlas::DoBlasGemmWithProfiling(
811 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
812 uint64_t n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
813 const DeviceMemory<double> &b, int ldb, double beta,
814 DeviceMemory<double> *c, int ldc,
815 blas::ProfileResult *output_profile_result) {
816 return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
817 lda, b, ldb, beta, c, ldc,
818 output_profile_result);
819 }
820
821 bool CUDABlas::DoBlasGemmWithProfiling(
822 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
823 uint64_t n, uint64 k, std::complex<float> alpha,
824 const DeviceMemory<std::complex<float>> &a, int lda,
825 const DeviceMemory<std::complex<float>> &b, int ldb,
826 std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
827 blas::ProfileResult *output_profile_result) {
828 return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
829 lda, b, ldb, beta, c, ldc,
830 output_profile_result);
831 }
832
833 bool CUDABlas::DoBlasGemmWithProfiling(
834 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
835 uint64_t n, uint64 k, std::complex<double> alpha,
836 const DeviceMemory<std::complex<double>> &a, int lda,
837 const DeviceMemory<std::complex<double>> &b, int ldb,
838 std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
839 blas::ProfileResult *output_profile_result) {
840 return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
841 lda, b, ldb, beta, c, ldc,
842 output_profile_result);
843 }
844
845 template <typename T>
846 bool CUDABlas::DoBlasGemvWithProfilingImpl(
847 Stream *stream, blas::Transpose trans, uint64_t m, uint64 n, const T &alpha,
848 const DeviceMemory<T> &a, int lda, const DeviceMemory<T> &x, int incx,
849 const T &beta, DeviceMemory<T> *y, int incy,
850 blas::ProfileResult *output_profile_result) {
851 std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
852 if (output_profile_result != nullptr) {
853 timer.reset(new GpuTimer(parent_));
854 if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
855 return false;
856 }
857 }
858
859 // Call blasGemv
860 bool result =
861 DoBlasGemv(stream, trans, m, n, alpha, a, lda, x, incx, beta, y, incy);
862
863 if (timer != nullptr && result) {
864 // GpuTimer will CHECK-fail if we Stop() it while the stream is in an error
865 // state.
866 if (!timer->Stop(AsGpuStream(stream))) {
867 return false;
868 }
869 output_profile_result->set_is_valid(true);
870 output_profile_result->set_algorithm(blas::kDefaultBlasGemv);
871 output_profile_result->set_elapsed_time_in_ms(
872 timer->GetElapsedMilliseconds());
873 }
874 return result;
875 }
876
877 template <typename T, typename ParamType>
878 bool CUDABlas::DoBlasGemmWithProfilingImpl(
879 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
880 uint64_t n, uint64 k, const ParamType &alpha, const DeviceMemory<T> &a,
881 int lda, const DeviceMemory<T> &b, int ldb, const ParamType &beta,
882 DeviceMemory<T> *c, int ldc, blas::ProfileResult *output_profile_result) {
883 std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
884 if (output_profile_result != nullptr) {
885 timer.reset(new GpuTimer(parent_));
886 if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
887 return false;
888 }
889 }
890
891 // Call blasGemm
892 bool result = DoBlasGemm(stream, transa, transb, m, n, k,
893 blas::ToDataType<T>::value, &alpha, a, lda, b, ldb,
894 &beta, c, ldc, blas::kDefaultComputePrecision)
895 .ok();
896
897 if (timer != nullptr && result) {
898 // GpuTimer will CHECK-fail if we Stop() it while the stream is in an error
899 // state.
900 if (!timer->Stop(AsGpuStream(stream))) {
901 return false;
902 }
903 output_profile_result->set_is_valid(true);
904 output_profile_result->set_algorithm(blas::kDefaultBlasGemm);
905 output_profile_result->set_elapsed_time_in_ms(
906 timer->GetElapsedMilliseconds());
907 }
908 return result;
909 }
910
911 static bool UsesTensorOps(blas::AlgorithmType algo) {
912 #if CUDA_VERSION >= 9000
913 cublasGemmAlgo_t cublas_algo = static_cast<cublasGemmAlgo_t>(algo);
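  // In cublasGemmAlgo_t, every *_TENSOR_OP enumerator is numbered at or above
  // CUBLAS_GEMM_DEFAULT_TENSOR_OP, so a single comparison identifies them.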
914 return cublas_algo >= CUBLAS_GEMM_DEFAULT_TENSOR_OP;
915 #else
916 return false;
917 #endif
918 }
919
920 static port::StatusOr<cublasMath_t> GetMathTypeForGemmEx(
921 Stream *stream, blas::AlgorithmType algorithm, blas::DataType type_a,
922 blas::DataType type_b) {
923 if (type_a != type_b) {
924 return port::InternalError("Types of inputs mismatch");
925 }
926
927 // GPUs < sm_50 don't support cublasGemmEx.
928 CudaComputeCapability cc = stream->GetCudaComputeCapability();
929 if (cc.major < 5) {
930 return port::InternalError(absl::StrCat(
931 "sm_", cc.major, " does not support explicit gemm algorithms."));
932 }
933
934 bool algo_uses_tensor_ops = UsesTensorOps(algorithm);
935 cublasMath_t math_type = CUBLAS_DEFAULT_MATH;
936 if (algo_uses_tensor_ops) {
937 if (cc.major < 7) {
938 return port::InternalError(absl::StrCat(
939 "Algorithm ", algorithm,
940 " uses tensor ops, but tensor ops are not available in sm", cc.major,
941 "X devices."));
942 } else if (type_a == blas::DataType::kFloat) {
943 #if CUDA_VERSION < 11000
944 return port::InternalError(absl::StrCat(
945 "Algorithm ", algorithm,
946 " uses tensor ops, but tensor ops are not available for fp32"));
947 #else
948 if (cc.major < 8) {
949 return port::InternalError(absl::StrCat(
950 "Algorithm ", algorithm,
951 " uses tensor ops, but tensor ops are not available in sm",
952 cc.major, "X devices for float input types."));
953 } else if (!tensorflow::tensor_float_32_execution_enabled()) {
954 return port::InternalError(absl::StrCat(
955 "Algorithm ", algorithm,
956 " uses tensor ops, but tensor ops are disabled for fp32 inputs"));
957 }
958 math_type = CUBLAS_TF32_TENSOR_OP_MATH;
959 #endif
960 } else if (type_a == blas::DataType::kHalf) {
961 #if CUDA_VERSION < 11000
962 math_type = CUBLAS_TENSOR_OP_MATH;
963 #endif
964 } else {
965 return port::InternalError(
966 absl::StrCat("Algorithm ", algorithm,
967 " uses tensor ops which are not supported for input"));
968 }
969 }
970
971 // Return an error if we might be hitting a cuBLAS bug that produces the wrong
972 // result. See nvbugs/2156201, b/79126339.
973 #if CUDA_VERSION >= 9000 && CUDA_VERSION < 9020
974 if ((algorithm == CUBLAS_GEMM_DEFAULT || algorithm >= CUBLAS_GEMM_ALGO13) &&
975 std::max({m, n, k}) >= 2097153 && cc.major < 7) {
976 return port::InternalError(
977 "DoBlasGemmWithAlgorithm returning false to work around cudnn "
978 "<9.2 bug with m, n, or k >= 2097153. See b/79126339.");
979 }
980 #endif
981 return math_type;
982 }
983
984 static port::StatusOr<std::unique_ptr<GpuTimer, GpuTimerDeleter>>
985 StartGpuTimerForProfile(Stream *stream, GpuExecutor *executor,
986 blas::ProfileResult *output_profile_result) {
987 std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
988 if (output_profile_result) {
989 timer.reset(new GpuTimer(executor));
990 if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
991 return port::InternalError(
992 "output_profile_result given, but unable to create a GpuTimer");
993 }
994 }
995 return timer;
996 }
997
998 static port::Status PopulateProfileFromTimer(
999 GpuTimer *timer, blas::AlgorithmType algorithm,
1000 blas::ProfileResult *output_profile_result, Stream *stream) {
1001 if (timer) {
1002 // GpuTimer will CHECK-fail if we Stop() it while the stream is in an error
1003 // state.
1004 if (!timer->Stop(AsGpuStream(stream))) {
1005 return port::InternalError("unable to stop GpuTimer.");
1006 }
1007 output_profile_result->set_is_valid(true);
1008 output_profile_result->set_algorithm(algorithm);
1009 output_profile_result->set_elapsed_time_in_ms(
1010 timer->GetElapsedMilliseconds());
1011 }
1012 return ::tensorflow::OkStatus();
1013 }
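
// These two helpers bracket a profiled cuBLAS call: StartGpuTimerForProfile
// runs before the DoBlasInternalImpl invocation and PopulateProfileFromTimer
// runs after it, as in DoBlasGemmWithAlgorithm below.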
1014
1015 port::Status CUDABlas::DoBlasGemmWithAlgorithm(
1016 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
1017 uint64_t n, uint64 k, const void *alpha, const DeviceMemoryBase &a,
1018 blas::DataType type_a, int lda, const DeviceMemoryBase &b,
1019 blas::DataType type_b, int ldb, const void *beta, DeviceMemoryBase *c,
1020 blas::DataType type_c, int ldc, blas::ComputationType computation_type,
1021 blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
1022 TF_ASSIGN_OR_RETURN(cublasMath_t math_type,
1023 GetMathTypeForGemmEx(stream, algorithm, type_a, type_b));
1024
1025 TF_ASSIGN_OR_RETURN(auto timer, StartGpuTimerForProfile(
1026 stream, parent_, output_profile_result));
1027
1028 // Since we are converting 'algorithm' to cublasGemmAlgo_t by static_cast,
1029 // we do the following compile-time check on the default value:
1030 static_assert(blas::kDefaultGemmAlgo == CUBLAS_GEMM_DFALT, "");
1031
1032 TF_RETURN_IF_ERROR(DoBlasInternalImpl(
1033 AS_LAMBDA(cublasGemmEx), stream, /*pointer_mode_host=*/true, math_type,
1034 AsCublasOperation(transa), AsCublasOperation(transb), m, n, k, alpha,
1035 a.opaque(), AsCudaDataType(type_a), lda, b.opaque(),
1036 AsCudaDataType(type_b), ldb, beta, c->opaque(), AsCudaDataType(type_c),
1037 ldc, AsCublasComputeType(computation_type),
1038 static_cast<cublasGemmAlgo_t>(algorithm)));
1039 TF_RETURN_IF_ERROR(PopulateProfileFromTimer(timer.get(), algorithm,
1040 output_profile_result, stream));
1041 return ::tensorflow::OkStatus();
1042 }
1043
1044 port::Status CUDABlas::DoBlasGemmStridedBatchedWithAlgorithm(
1045 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
1046 uint64_t n, uint64 k, const void *alpha, const DeviceMemoryBase &a,
1047 blas::DataType type_a, int lda, int64_t stride_a, const DeviceMemoryBase &b,
1048 blas::DataType type_b, int ldb, int64_t stride_b, const void *beta,
1049 DeviceMemoryBase *c, blas::DataType type_c, int ldc, int64_t stride_c,
1050 int batch_count, blas::ComputationType computation_type,
1051 blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
1052 TF_ASSIGN_OR_RETURN(cublasMath_t math_type,
1053 GetMathTypeForGemmEx(stream, algorithm, type_a, type_b));
1054 TF_ASSIGN_OR_RETURN(auto timer, StartGpuTimerForProfile(
1055 stream, parent_, output_profile_result));
1056
1057 cudaDataType_t cuda_in_type = AsCudaDataType(type_a);
1058
1059 #if CUDA_VERSION >= 11000
1060 // Work around a CUDA bug where this batched GEMM is erroneously marked as
1061 // unsupported on Pascal, by manually unbatching it there.
1062 if (cuda_in_type == CUDA_R_16BF &&
1063 !stream->GetCudaComputeCapability().IsAtLeast(7)) {
1064 for (int batch = 0; batch < batch_count; ++batch) {
1065 const auto *a_matrix = reinterpret_cast<const __nv_bfloat16 *>(
1066 static_cast<const Eigen::bfloat16 *>(a.opaque()) + batch * stride_a);
1067 const auto *b_matrix = reinterpret_cast<const __nv_bfloat16 *>(
1068 static_cast<const Eigen::bfloat16 *>(b.opaque()) + batch * stride_b);
1069 auto *c_matrix = reinterpret_cast<__nv_bfloat16 *>(
1070 static_cast<Eigen::bfloat16 *>(c->opaque()) + batch * stride_c);
1071 TF_RETURN_IF_ERROR(DoBlasInternalImpl(
1072 AS_LAMBDA(cublasGemmEx), stream, /*pointer_mode_host=*/true,
1073 math_type, AsCublasOperation(transa), AsCublasOperation(transb), m, n,
1074 k, static_cast<const float *>(alpha), a_matrix, CUDA_R_16BF, lda,
1075 b_matrix, CUDA_R_16BF, ldb, static_cast<const float *>(beta),
1076 c_matrix, CUDA_R_16BF, ldc, AsCublasComputeType(computation_type),
1077 static_cast<cublasGemmAlgo_t>(algorithm)));
1078 }
1079 TF_RETURN_IF_ERROR(PopulateProfileFromTimer(timer.get(), algorithm,
1080 output_profile_result, stream));
1081 return ::tensorflow::OkStatus();
1082 }
1083 #endif
1084
1085 TF_RETURN_IF_ERROR(DoBlasInternalImpl(
1086 AS_LAMBDA(cublasGemmStridedBatchedEx), stream, /*pointer_mode_host=*/true,
1087 math_type, AsCublasOperation(transa), AsCublasOperation(transb), m, n, k,
1088 alpha, a.opaque(), cuda_in_type, lda, stride_a, b.opaque(), cuda_in_type,
1089 ldb, stride_b, beta, c->opaque(), AsCudaDataType(type_c), ldc, stride_c,
1090 batch_count, AsCublasComputeType(computation_type),
1091 static_cast<cublasGemmAlgo_t>(algorithm)));
1092 TF_RETURN_IF_ERROR(PopulateProfileFromTimer(timer.get(), algorithm,
1093 output_profile_result, stream));
1094 return ::tensorflow::OkStatus();
1095 }
1096
1097 bool CUDABlas::GetBlasGemmAlgorithms(
1098 Stream *stream, std::vector<blas::AlgorithmType> *out_algorithms) {
1099 // cublasGemmAlgo_t (and the function that accepts this type, cublasGemmEx)
1100 // were first introduced in CUDA 8.
1101 //
1102 // Note that when the CUDA version and compute capability are not sufficient,
1103 // we still return out_algorithms. The caller needs to make sure that in this
1104 // case the returned vector is empty.
1105 if (stream->GetCudaComputeCapability().IsAtLeast(
1106 CudaComputeCapability::AMPERE)) {
1107 // Note: for NVIDIA Ampere Architecture GPUs and beyond, i.e. SM version >=
1108 // 80, the numbered algorithm options are equivalent to CUBLAS_GEMM_DEFAULT
1109 // or CUBLAS_GEMM_DEFAULT_TENSOR_OP respectively.
1110 *out_algorithms = {
1111 CUBLAS_GEMM_DFALT,
1112 CUBLAS_GEMM_DFALT_TENSOR_OP,
1113 };
1114 } else {
1115 *out_algorithms = {
1116 CUBLAS_GEMM_DFALT,
1117 CUBLAS_GEMM_ALGO0,
1118 CUBLAS_GEMM_ALGO1,
1119 CUBLAS_GEMM_ALGO2,
1120 CUBLAS_GEMM_ALGO3,
1121 CUBLAS_GEMM_ALGO4,
1122 CUBLAS_GEMM_ALGO5,
1123 CUBLAS_GEMM_ALGO6,
1124 CUBLAS_GEMM_ALGO7,
1125 #if CUDA_VERSION >= 9000
1126 CUBLAS_GEMM_ALGO8,
1127 CUBLAS_GEMM_ALGO9,
1128 CUBLAS_GEMM_ALGO10,
1129 CUBLAS_GEMM_ALGO11,
1130 CUBLAS_GEMM_ALGO12,
1131 CUBLAS_GEMM_ALGO13,
1132 CUBLAS_GEMM_ALGO14,
1133 CUBLAS_GEMM_ALGO15,
1134 CUBLAS_GEMM_ALGO16,
1135 CUBLAS_GEMM_ALGO17,
1136 CUBLAS_GEMM_DFALT_TENSOR_OP,
1137 CUBLAS_GEMM_ALGO0_TENSOR_OP,
1138 CUBLAS_GEMM_ALGO1_TENSOR_OP,
1139 CUBLAS_GEMM_ALGO2_TENSOR_OP,
1140 CUBLAS_GEMM_ALGO3_TENSOR_OP,
1141 CUBLAS_GEMM_ALGO4_TENSOR_OP,
1142 #endif
1143 #if CUDA_VERSION >= 9020
1144 CUBLAS_GEMM_ALGO18,
1145 CUBLAS_GEMM_ALGO19,
1146 CUBLAS_GEMM_ALGO20,
1147 CUBLAS_GEMM_ALGO21,
1148 CUBLAS_GEMM_ALGO22,
1149 CUBLAS_GEMM_ALGO23,
1150 CUBLAS_GEMM_ALGO5_TENSOR_OP,
1151 CUBLAS_GEMM_ALGO6_TENSOR_OP,
1152 CUBLAS_GEMM_ALGO7_TENSOR_OP,
1153 CUBLAS_GEMM_ALGO8_TENSOR_OP,
1154 CUBLAS_GEMM_ALGO9_TENSOR_OP,
1155 CUBLAS_GEMM_ALGO10_TENSOR_OP,
1156 CUBLAS_GEMM_ALGO11_TENSOR_OP,
1157 CUBLAS_GEMM_ALGO12_TENSOR_OP,
1158 CUBLAS_GEMM_ALGO13_TENSOR_OP,
1159 CUBLAS_GEMM_ALGO14_TENSOR_OP,
1160 CUBLAS_GEMM_ALGO15_TENSOR_OP,
1161 #endif
1162 };
1163 }
1164 return true;
1165 }
1166
1167 template <typename T>
1168 struct HalfAsFloat {
1169 typedef T type;
1170 };
1171
1172 template <>
1173 struct HalfAsFloat<Eigen::half> {
1174 typedef float type;
1175 };
1176
1177 namespace {
1178 // Pass-through for non-complex types that don't need conversion to a
1179 // cublas-specific type.
1180 template <typename T>
1181 T inline GpuComplexValue(T v) {
1182 return v;
1183 }
1184 } // namespace
1185
1186 template <typename T, typename Scalar, typename FuncT>
1187 port::Status CUDABlas::DoBlasGemmBatchedInternal(
1188 FuncT cublas_func, Stream *stream, blas::Transpose transa,
1189 blas::Transpose transb, uint64_t m, uint64 n, uint64 k, Scalar alpha,
1190 const DeviceMemorySlice<T> &a_ptrs_to_wrappers, int lda,
1191 const DeviceMemorySlice<T> &b_ptrs_to_wrappers, int ldb, Scalar beta,
1192 const DeviceMemorySlice<T> &c_ptrs_to_wrappers, int ldc, int batch_count,
1193 ScratchAllocator *scratch_allocator) {
1194 std::vector<T *> a_raw_ptrs, b_raw_ptrs, c_raw_ptrs;
1195 for (int i = 0; i < batch_count; ++i) {
1196 a_raw_ptrs.push_back(static_cast<T *>(a_ptrs_to_wrappers[i]->opaque()));
1197 b_raw_ptrs.push_back(static_cast<T *>(b_ptrs_to_wrappers[i]->opaque()));
1198 c_raw_ptrs.push_back(static_cast<T *>(c_ptrs_to_wrappers[i]->opaque()));
1199 }
1200
1201 typedef typename HalfAsFloat<typename GpuComplexT<T>::type>::type CUDA_T;
1202
1203 const size_t size = batch_count * sizeof(CUDA_T *);
1204
1205 // Device-side copy of pointers to matrices.
1206 DeviceMemory<CUDA_T *> a;
1207 DeviceMemory<CUDA_T *> b;
1208 DeviceMemory<CUDA_T *> c;
1209
1210 // If temporary space is allocated for device-side copies of pointers to
1211 // matrices, that temporary space should not be freed until this function
1212 // returns. Although the values for these unique_ptrs are not set here, they
1213 // are declared at this scope so they will be destroyed when the function
1214 // returns.
1215 //
1216 // If a scratch allocator is provided, these pointers will not be used at all.
1217 std::unique_ptr<TemporaryDeviceMemory<CUDA_T *>> a_temporary;
1218 std::unique_ptr<TemporaryDeviceMemory<CUDA_T *>> b_temporary;
1219 std::unique_ptr<TemporaryDeviceMemory<CUDA_T *>> c_temporary;
1220
1221 // Decide how to allocate device-side copy of pointers to matrices based on
1222 // whether a scratch allocator was passed.
1223 if (scratch_allocator != nullptr) {
1224 TF_ASSIGN_OR_RETURN(DeviceMemory<uint8> a_bytes,
1225 scratch_allocator->AllocateBytes(size));
1226 TF_ASSIGN_OR_RETURN(DeviceMemory<uint8> b_bytes,
1227 scratch_allocator->AllocateBytes(size));
1228 TF_ASSIGN_OR_RETURN(DeviceMemory<uint8> c_bytes,
1229 scratch_allocator->AllocateBytes(size));
1230 a = DeviceMemory<CUDA_T *>(a_bytes);
1231 b = DeviceMemory<CUDA_T *>(b_bytes);
1232 c = DeviceMemory<CUDA_T *>(c_bytes);
1233 } else {
1234 TF_ASSIGN_OR_RETURN(a_temporary,
1235 stream->AllocateTemporaryArray<CUDA_T *>(batch_count));
1236 TF_ASSIGN_OR_RETURN(b_temporary,
1237 stream->AllocateTemporaryArray<CUDA_T *>(batch_count));
1238 TF_ASSIGN_OR_RETURN(c_temporary,
1239 stream->AllocateTemporaryArray<CUDA_T *>(batch_count));
1240 a = DeviceMemory<CUDA_T *>(*a_temporary->mutable_device_memory());
1241 b = DeviceMemory<CUDA_T *>(*b_temporary->mutable_device_memory());
1242 c = DeviceMemory<CUDA_T *>(*c_temporary->mutable_device_memory());
1243 }
1244
1245 if (!stream->ThenMemcpy(&a, a_raw_ptrs.data(), size).ok() ||
1246 !stream->ThenMemcpy(&b, b_raw_ptrs.data(), size).ok() ||
1247 !stream->ThenMemcpy(&c, c_raw_ptrs.data(), size).ok()) {
1248 return port::Status(port::error::INTERNAL,
1249 "failed to copy memory from host to device in "
1250 "CUDABlas::DoBlasGemmBatched");
1251 }
1252
1253 cudaDataType_t data_type = CUDADataType<T>::type;
1254
1255 #if CUDA_VERSION >= 9010
1256 if (stream->GetCudaComputeCapability().IsAtLeast(5)) {
1257 cublasMath_t math_type;
1258 cublasGemmAlgo_t algo;
1259 if (data_type == CUDA_R_16F) {
1260 #if CUDA_VERSION < 11000
1261 math_type = CUBLAS_TENSOR_OP_MATH;
1262 #else
1263 math_type = CUBLAS_DEFAULT_MATH;
1264 #endif
1265 algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
1266 #if CUBLAS_VER_MAJOR >= 11
1267 } else if (data_type == CUDA_R_32F) {
1268 // DoBlasInternalImpl will switch math_type back to CUBLAS_DEFAULT_MATH
1269 // if TensorFloat-32 is disabled.
1270 math_type = CUBLAS_TF32_TENSOR_OP_MATH;
1271 algo = tensorflow::tensor_float_32_execution_enabled()
1272 ? CUBLAS_GEMM_DFALT_TENSOR_OP
1273 : CUBLAS_GEMM_DFALT;
1274 #endif
1275 } else {
1276 math_type = CUBLAS_DEFAULT_MATH;
1277 algo = CUBLAS_GEMM_DFALT;
1278 }
1279 cudaDataType_t compute_type =
1280 (data_type == CUDA_R_16F ? CUDA_R_32F : data_type);
1281 const void **a_void_ptrs = reinterpret_cast<const void **>(
1282 const_cast<const CUDA_T **>(GpuMemory(a)));
1283 const void **b_void_ptrs = reinterpret_cast<const void **>(
1284 const_cast<const CUDA_T **>(GpuMemory(b)));
1285 void **c_void_ptrs =
1286 reinterpret_cast<void **>(const_cast<CUDA_T **>(GpuMemory(c)));
1287 return DoBlasInternalImpl(
1288 AS_LAMBDA(cublasGemmBatchedEx), stream, true /* = pointer_mode_host */,
1289 math_type, AsCublasOperation(transa), AsCublasOperation(transb), m, n,
1290 k, &alpha, a_void_ptrs, data_type, lda, b_void_ptrs, data_type, ldb,
1291 &beta, c_void_ptrs, data_type, ldc, batch_count, compute_type, algo);
1292 }
1293 #endif
1294 // either CUDA_VERSION < 9.1 or SM < 5.0
1295 if (data_type != CUDA_R_16F) {
1296 auto cb_alpha = GpuComplexValue(alpha);
1297 auto cb_beta = GpuComplexValue(beta);
1298 bool ok = DoBlasInternal(
1299 cublas_func, stream, true /* = pointer_mode_host */,
1300 AsCublasOperation(transa), AsCublasOperation(transb), m, n, k,
1301 GpuComplex(&cb_alpha), const_cast<const CUDA_T **>(GpuMemory(a)), lda,
1302 const_cast<const CUDA_T **>(GpuMemory(b)), ldb, GpuComplex(&cb_beta),
1303 const_cast<CUDA_T **>(GpuMemory(c)), ldc, batch_count);
1304 if (ok) {
1305 return ::tensorflow::OkStatus();
1306 }
1307 return port::Status(port::error::INTERNAL,
1308 "failed BLAS call, see log for details");
1309 } else {
1310 // Fall back to a loop for fp16
1311 for (int b = 0; b < batch_count; ++b) {
1312 const DeviceMemory<T> &a_matrix = *a_ptrs_to_wrappers[b];
1313 const DeviceMemory<T> &b_matrix = *b_ptrs_to_wrappers[b];
1314 DeviceMemory<T> *c_matrix = c_ptrs_to_wrappers[b];
1315 TF_RETURN_IF_ERROR(DoBlasGemm(
1316 stream, transa, transb, m, n, k, blas::ToDataType<T>::value, &alpha,
1317 a_matrix, lda, b_matrix, ldb, &beta, c_matrix, ldc,
1318 blas::kDefaultComputePrecision));
1319 }
1320 return ::tensorflow::OkStatus();
1321 }
1322 }
1323
1324 bool CUDABlas::DoBlasGemmBatched(
1325 Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
1326 uint64_t n, uint64 k, float alpha,
1327 const DeviceMemorySlice<Eigen::half> &a_array, int lda,
1328 const DeviceMemorySlice<Eigen::half> &b_array, int ldb, float beta,
1329 const DeviceMemorySlice<Eigen::half> &c_array, int ldc, int batch_count,
1330 ScratchAllocator *scratch_allocator) {
1331 // Note: The func passed here (cublasSgemmBatched) is not actually called,
1332 // due to special handling of fp16 inside DoBlasGemmBatchedInternal.
  port::Status status = DoBlasGemmBatchedInternal(
      cublasSgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda,
      b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
  if (!status.ok()) {
    LOG(ERROR) << status;
  }
  return status.ok();
}

bool CUDABlas::DoBlasGemmBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
    uint64_t n, uint64 k, float alpha, const DeviceMemorySlice<float> &a_array,
    int lda, const DeviceMemorySlice<float> &b_array, int ldb, float beta,
    const DeviceMemorySlice<float> &c_array, int ldc, int batch_count,
    ScratchAllocator *scratch_allocator) {
  port::Status status = DoBlasGemmBatchedInternal(
      cublasSgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda,
      b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
  if (!status.ok()) {
    LOG(ERROR) << status;
  }
  return status.ok();
}

bool CUDABlas::DoBlasGemmBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
    uint64_t n, uint64 k, double alpha,
    const DeviceMemorySlice<double> &a_array, int lda,
    const DeviceMemorySlice<double> &b_array, int ldb, double beta,
    const DeviceMemorySlice<double> &c_array, int ldc, int batch_count,
    ScratchAllocator *scratch_allocator) {
  port::Status status = DoBlasGemmBatchedInternal(
      cublasDgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda,
      b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
  if (!status.ok()) {
    LOG(ERROR) << status;
  }
  return status.ok();
}

bool CUDABlas::DoBlasGemmBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
    uint64_t n, uint64 k, std::complex<float> alpha,
    const DeviceMemorySlice<std::complex<float>> &a_array, int lda,
    const DeviceMemorySlice<std::complex<float>> &b_array, int ldb,
    std::complex<float> beta,
    const DeviceMemorySlice<std::complex<float>> &c_array, int ldc,
    int batch_count, ScratchAllocator *scratch_allocator) {
  port::Status status = DoBlasGemmBatchedInternal(
      cublasCgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda,
      b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
  if (!status.ok()) {
    LOG(ERROR) << status;
  }
  return status.ok();
}

bool CUDABlas::DoBlasGemmBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
    uint64_t n, uint64 k, std::complex<double> alpha,
    const DeviceMemorySlice<std::complex<double>> &a_array, int lda,
    const DeviceMemorySlice<std::complex<double>> &b_array, int ldb,
    std::complex<double> beta,
    const DeviceMemorySlice<std::complex<double>> &c_array, int ldc,
    int batch_count, ScratchAllocator *scratch_allocator) {
  port::Status status = DoBlasGemmBatchedInternal(
      cublasZgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda,
      b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
  if (!status.ok()) {
    LOG(ERROR) << status;
  }
  return status.ok();
}

port::Status CUDABlas::DoBlasGemmStridedBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m,
    uint64_t n, uint64 k, blas::DataType dtype, const void *alpha,
    const DeviceMemoryBase &a, int lda, int64_t stride_a,
    const DeviceMemoryBase &b, int ldb, int64_t stride_b, const void *beta,
    DeviceMemoryBase *c, int ldc, int64_t stride_c, int batch_count) {
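  // Select the cuBLAS math mode for this GEMM. Pre-CUDA-11 builds must opt
  // into tensor-op math explicitly for fp16; on CUDA 11+ that opt-in is no
  // longer needed and TF32 tensor-op math is requested for fp32 instead.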
  cublasMath_t math_type = CUBLAS_DEFAULT_MATH;
#if CUDA_VERSION < 11000
  if (dtype == dnn::kHalf) {
    math_type = CUBLAS_TENSOR_OP_MATH;
  }
#else
  if (dtype == dnn::kFloat) {
    math_type = CUBLAS_TF32_TENSOR_OP_MATH;
  }
#endif

  switch (dtype) {
#if CUDA_VERSION >= 11000
    case dnn::kBF16: {
      CudaComputeCapability cc = stream->GetCudaComputeCapability();
      if (cc.IsAtLeast(7)) {
        // SM 7.0+ is guaranteed here, so the tensor-op algorithm is always
        // selected.
        cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
        return DoBlasInternalImpl(
            AS_LAMBDA(cublasGemmStridedBatchedEx), stream,
            true /* = pointer_mode_host */, math_type,
            AsCublasOperation(transa), AsCublasOperation(transb), m, n, k,
            alpha, a.opaque(), CUDA_R_16BF, lda, stride_a, b.opaque(),
            CUDA_R_16BF, ldb, stride_b, beta, c->opaque(), CUDA_R_16BF, ldc,
            stride_c, batch_count,
            /*compute_type=*/CUDA_R_32F, algo);
      }
      // SM < 7.0: fall back to one cublasSgemmEx call per batch, using bf16
      // storage with fp32 computation.
      for (int batch = 0; batch < batch_count; ++batch) {
        const auto *a_matrix = reinterpret_cast<const __nv_bfloat16 *>(
            static_cast<const Eigen::bfloat16 *>(a.opaque()) +
            batch * stride_a);
        const auto *b_matrix = reinterpret_cast<const __nv_bfloat16 *>(
            static_cast<const Eigen::bfloat16 *>(b.opaque()) +
            batch * stride_b);
        auto *c_matrix = reinterpret_cast<__nv_bfloat16 *>(
            static_cast<Eigen::bfloat16 *>(c->opaque()) + batch * stride_c);
        TF_RETURN_IF_ERROR(DoBlasInternalImpl(
            cublasSgemmEx, stream, true /* = pointer_mode_host */,
            CUBLAS_DEFAULT_MATH, AsCublasOperation(transa),
            AsCublasOperation(transb), m, n, k,
            static_cast<const float *>(alpha), a_matrix, CUDA_R_16BF, lda,
            b_matrix, CUDA_R_16BF, ldb, static_cast<const float *>(beta),
            c_matrix, CUDA_R_16BF, ldc));
      }
      return ::tensorflow::OkStatus();
    }
#endif
    case dnn::kHalf: {
#if CUDA_VERSION >= 9010
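      // On CUDA 9.1+ with SM 5.0+, issue a single cublasGemmStridedBatchedEx
      // call: fp16 operands with fp32 accumulation, using the tensor-op
      // algorithm on SM 7.0+.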
      CudaComputeCapability cc = stream->GetCudaComputeCapability();
      if (cc.major >= 5) {
        cublasGemmAlgo_t algo =
            (cc.major >= 7 ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT);
        return DoBlasInternalImpl(
            AS_LAMBDA(cublasGemmStridedBatchedEx), stream,
            true /* = pointer_mode_host */, math_type,
            AsCublasOperation(transa), AsCublasOperation(transb), m, n, k,
            alpha, a.opaque(), CUDA_R_16F, lda, stride_a, b.opaque(),
            CUDA_R_16F, ldb, stride_b, beta, c->opaque(), CUDA_R_16F, ldc,
            stride_c, batch_count, CUDA_R_32F, algo);
      }
#endif
      // Either CUDA_VERSION < 9.1 or SM < 5.0, so cublasGemmStridedBatchedEx
      // cannot be used. Fall back to one cublasSgemmEx call per batch (fp16
      // storage, fp32 computation).
      for (int batch = 0; batch < batch_count; ++batch) {
        const auto *a_matrix = reinterpret_cast<const __half *>(
            static_cast<const Eigen::half *>(a.opaque()) + batch * stride_a);
        const auto *b_matrix = reinterpret_cast<const __half *>(
            static_cast<const Eigen::half *>(b.opaque()) + batch * stride_b);
        auto *c_matrix = reinterpret_cast<__half *>(
            static_cast<Eigen::half *>(c->opaque()) + batch * stride_c);
        TF_RETURN_IF_ERROR(DoBlasInternalImpl(
            cublasSgemmEx, stream, true /* = pointer_mode_host */,
            CUBLAS_DEFAULT_MATH, AsCublasOperation(transa),
            AsCublasOperation(transb), m, n, k,
            static_cast<const float *>(alpha), a_matrix, SE_CUDA_DATA_HALF, lda,
            b_matrix, SE_CUDA_DATA_HALF, ldb, static_cast<const float *>(beta),
            c_matrix, SE_CUDA_DATA_HALF, ldc));
      }
      return ::tensorflow::OkStatus();
    }
    case dnn::kFloat: {
      return DoBlasInternalImpl(
          cublasSgemmStridedBatched, stream, true /* = pointer_mode_host */,
          math_type, AsCublasOperation(transa), AsCublasOperation(transb), m, n,
          k, static_cast<const float *>(alpha),
          static_cast<const float *>(a.opaque()), lda, stride_a,
          static_cast<const float *>(b.opaque()), ldb, stride_b,
          static_cast<const float *>(beta), static_cast<float *>(c->opaque()),
          ldc, stride_c, batch_count);
    }
    case dnn::kDouble:
      return DoBlasInternalImpl(
          cublasDgemmStridedBatched, stream, true /* = pointer_mode_host */,
          math_type, AsCublasOperation(transa), AsCublasOperation(transb), m, n,
          k, static_cast<const double *>(alpha),
          static_cast<const double *>(a.opaque()), lda, stride_a,
          static_cast<const double *>(b.opaque()), ldb, stride_b,
          static_cast<const double *>(beta), static_cast<double *>(c->opaque()),
          ldc, stride_c, batch_count);
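    // For the complex cases, the host-side std::complex scalars are converted
    // to the cuBLAS complex scalar types before dispatch; the device buffers
    // are reinterpreted in place since the layouts match.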
    case dnn::kComplexFloat: {
      GpuComplexType cb_alpha =
          GpuComplexValue(*static_cast<const std::complex<float> *>(alpha));
      GpuComplexType cb_beta =
          GpuComplexValue(*static_cast<const std::complex<float> *>(beta));
      return DoBlasInternalImpl(
          cublasCgemmStridedBatched, stream, true /* = pointer_mode_host */,
          math_type, AsCublasOperation(transa), AsCublasOperation(transb), m, n,
          k, GpuComplex(&cb_alpha),
          static_cast<const GpuComplexType *>(a.opaque()), lda, stride_a,
          static_cast<const GpuComplexType *>(b.opaque()), ldb, stride_b,
          GpuComplex(&cb_beta), static_cast<GpuComplexType *>(c->opaque()), ldc,
          stride_c, batch_count);
    }
    case dnn::kComplexDouble: {
      GpuDoubleComplexType cb_alpha =
          GpuComplexValue(*static_cast<const std::complex<double> *>(alpha));
      GpuDoubleComplexType cb_beta =
          GpuComplexValue(*static_cast<const std::complex<double> *>(beta));
      return DoBlasInternalImpl(
          cublasZgemmStridedBatched, stream, true /* = pointer_mode_host */,
          math_type, AsCublasOperation(transa), AsCublasOperation(transb), m, n,
          k, GpuComplex(&cb_alpha),
          static_cast<const GpuDoubleComplexType *>(a.opaque()), lda, stride_a,
          static_cast<const GpuDoubleComplexType *>(b.opaque()), ldb, stride_b,
          GpuComplex(&cb_beta),
          static_cast<GpuDoubleComplexType *>(c->opaque()), ldc, stride_c,
          batch_count);
    }
    default:
      return port::InternalError(absl::StrCat(
          "Unsupported datatype for GEMM: ", blas::DataTypeString(dtype)));
  }
}

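// TRSM overwrites B with the solution X of op(A) * X = alpha * B (left side)
// or X * op(A) = alpha * B (right side), where A is triangular.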
bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, blas::Transpose transa,
                          blas::Diagonal diag, uint64_t m, uint64 n,
                          float alpha, const DeviceMemory<float> &a, int lda,
                          DeviceMemory<float> *b, int ldb) {
  return DoBlasInternal(cublasStrsm, stream, true /* = pointer_mode_host */,
                        CUDABlasSide(side), CUDABlasUpperLower(uplo),
                        AsCublasOperation(transa), CUDABlasDiagonal(diag), m, n,
                        &alpha, GpuMemory(a), lda, GpuMemoryMutable(b), ldb);
}

bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, blas::Transpose transa,
                          blas::Diagonal diag, uint64_t m, uint64 n,
                          double alpha, const DeviceMemory<double> &a, int lda,
                          DeviceMemory<double> *b, int ldb) {
  return DoBlasInternal(cublasDtrsm, stream, true /* = pointer_mode_host */,
                        CUDABlasSide(side), CUDABlasUpperLower(uplo),
                        AsCublasOperation(transa), CUDABlasDiagonal(diag), m, n,
                        &alpha, GpuMemory(a), lda, GpuMemoryMutable(b), ldb);
}

bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, blas::Transpose transa,
                          blas::Diagonal diag, uint64_t m, uint64 n,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          DeviceMemory<std::complex<float>> *b, int ldb) {
  auto cb_alpha = GpuComplexValue(alpha);
  return DoBlasInternal(cublasCtrsm, stream, true /* = pointer_mode_host */,
                        CUDABlasSide(side), CUDABlasUpperLower(uplo),
                        AsCublasOperation(transa), CUDABlasDiagonal(diag), m, n,
                        GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
                        GpuComplex(GpuMemoryMutable(b)), ldb);
}

bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, blas::Transpose transa,
                          blas::Diagonal diag, uint64_t m, uint64 n,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          DeviceMemory<std::complex<double>> *b, int ldb) {
  auto cb_alpha = GpuComplexValue(alpha);
  return DoBlasInternal(cublasZtrsm, stream, true /* = pointer_mode_host */,
                        CUDABlasSide(side), CUDABlasUpperLower(uplo),
                        AsCublasOperation(transa), CUDABlasDiagonal(diag), m, n,
                        GpuComplex(&cb_alpha), GpuComplex(GpuMemory(a)), lda,
                        GpuComplex(GpuMemoryMutable(b)), ldb);
}

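// In the batched TRSM variants, `as` and `bs` are device arrays of pointers to
// the individual A and B matrices of each problem in the batch.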
bool CUDABlas::DoBlasTrsmBatched(Stream *stream, blas::Side side,
                                 blas::UpperLower uplo, blas::Transpose transa,
                                 blas::Diagonal diag, uint64_t m, uint64 n,
                                 float alpha, const DeviceMemory<float *> &as,
                                 int lda, DeviceMemory<float *> *bs, int ldb,
                                 int batch_count) {
  return DoBlasInternal(cublasStrsmBatched, stream,
                        true /* = pointer_mode_host */, CUDABlasSide(side),
                        CUDABlasUpperLower(uplo), AsCublasOperation(transa),
                        CUDABlasDiagonal(diag), m, n, &alpha, GpuMemory(as),
                        lda, GpuMemoryMutable(bs), ldb, batch_count);
}

bool CUDABlas::DoBlasTrsmBatched(Stream *stream, blas::Side side,
                                 blas::UpperLower uplo, blas::Transpose transa,
                                 blas::Diagonal diag, uint64_t m, uint64 n,
                                 double alpha, const DeviceMemory<double *> &as,
                                 int lda, DeviceMemory<double *> *bs, int ldb,
                                 int batch_count) {
  return DoBlasInternal(cublasDtrsmBatched, stream,
                        true /* = pointer_mode_host */, CUDABlasSide(side),
                        CUDABlasUpperLower(uplo), AsCublasOperation(transa),
                        CUDABlasDiagonal(diag), m, n, &alpha, GpuMemory(as),
                        lda, GpuMemoryMutable(bs), ldb, batch_count);
}

bool CUDABlas::DoBlasTrsmBatched(Stream *stream, blas::Side side,
                                 blas::UpperLower uplo, blas::Transpose transa,
                                 blas::Diagonal diag, uint64_t m, uint64 n,
                                 std::complex<float> alpha,
                                 const DeviceMemory<std::complex<float> *> &as,
                                 int lda,
                                 DeviceMemory<std::complex<float> *> *bs,
                                 int ldb, int batch_count) {
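  // std::complex<float> is layout-compatible with cuComplex/float2, so the
  // device pointer arrays can be reinterpreted directly for cuBLAS.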
  auto cb_alpha = GpuComplexValue(alpha);
  return DoBlasInternal(
      cublasCtrsmBatched, stream, true /* = pointer_mode_host */,
      CUDABlasSide(side), CUDABlasUpperLower(uplo), AsCublasOperation(transa),
      CUDABlasDiagonal(diag), m, n, &cb_alpha,
      reinterpret_cast<float2 *const *>(GpuMemory(as)), lda,
      reinterpret_cast<float2 **>(GpuMemoryMutable(bs)), ldb, batch_count);
}

bool CUDABlas::DoBlasTrsmBatched(Stream *stream, blas::Side side,
                                 blas::UpperLower uplo, blas::Transpose transa,
                                 blas::Diagonal diag, uint64_t m, uint64 n,
                                 std::complex<double> alpha,
                                 const DeviceMemory<std::complex<double> *> &as,
                                 int lda,
                                 DeviceMemory<std::complex<double> *> *bs,
                                 int ldb, int batch_count) {
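  // Same pointer reinterpretation as in the single-precision variant above,
  // using double2 for double-precision complex.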
  auto cb_alpha = GpuComplexValue(alpha);
  return DoBlasInternal(
      cublasZtrsmBatched, stream, true /* = pointer_mode_host */,
      CUDABlasSide(side), CUDABlasUpperLower(uplo), AsCublasOperation(transa),
      CUDABlasDiagonal(diag), m, n, &cb_alpha,
      reinterpret_cast<double2 *const *>(GpuMemory(as)), lda,
      reinterpret_cast<double2 **>(GpuMemoryMutable(bs)), ldb, batch_count);
}

port::Status CUDABlas::GetVersion(std::string *version) {
  absl::MutexLock lock(&mu_);

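  // cublasGetVersion reports the library version as a single encoded integer.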
  int v;
  auto status = cublasGetVersion(blas_, &v);
  if (status != CUBLAS_STATUS_SUCCESS) {
    return port::InternalError(ToString(status));
  }
  *version = std::to_string(v);
  return ::tensorflow::OkStatus();
}

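// Registers the cuBLAS plugin with the StreamExecutor PluginRegistry. The
// factory builds a CUDABlas instance for a CUDA GpuExecutor, and cuBLAS is
// then made the default BLAS plugin for the CUDA platform.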
void initialize_cublas() {
  port::Status status =
      PluginRegistry::Instance()->RegisterFactory<PluginRegistry::BlasFactory>(
          kCudaPlatformId, kCuBlasPlugin, "cuBLAS",
          [](::stream_executor::internal::StreamExecutorInterface *parent)
              -> blas::BlasSupport * {
            gpu::GpuExecutor *cuda_executor =
                dynamic_cast<gpu::GpuExecutor *>(parent);
            if (cuda_executor == nullptr) {
              LOG(ERROR)
                  << "Attempting to initialize an instance of the cuBLAS "
                  << "support library with a non-CUDA StreamExecutor";
              return nullptr;
            }

            CUDABlas *blas = new CUDABlas(cuda_executor);
            if (!blas->Init()) {
              // Note: Init() will log a more specific error.
              delete blas;
              return nullptr;
            }
            return blas;
          });

  if (!status.ok()) {
    LOG(ERROR) << "Unable to register cuBLAS factory: "
               << status.error_message();
  }

  PluginRegistry::Instance()->SetDefaultFactory(
      cuda::kCudaPlatformId, PluginKind::kBlas, kCuBlasPlugin);
}

}  // namespace cuda
}  // namespace stream_executor

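// Runs initialize_cublas() with the other StreamExecutor module initializers
// at startup so the plugin is registered before any BLAS support is requested.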
REGISTER_MODULE_INITIALIZER(register_cublas,
                            { stream_executor::cuda::initialize_cublas(); });