1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_blas_utils.h" 17 18 #include "absl/strings/str_cat.h" 19 #include "third_party/gpus/cuda/include/cublas_v2.h" 20 #include "third_party/gpus/cuda/include/cuda.h" 21 #include "tensorflow/compiler/xla/stream_executor/blas.h" 22 23 namespace stream_executor { 24 namespace cuda { 25 ToString(cublasStatus_t status)26const char* ToString(cublasStatus_t status) { 27 #if CUDA_VERSION >= 11050 // `GetStatusString` was added in 11.4 update 2. 28 return cublasGetStatusString(status); 29 #else 30 return "cublas error"; 31 #endif // CUDA_VERSION >= 11050 32 } 33 ToStatus(cublasStatus_t status,const char * prefix)34port::Status ToStatus(cublasStatus_t status, const char* prefix) { 35 if (status != CUBLAS_STATUS_SUCCESS) { 36 return port::Status(port::error::INTERNAL, 37 absl::StrCat(prefix, ": ", ToString(status))); 38 } 39 return port::Status::OK(); 40 } 41 AsCudaDataType(blas::DataType type)42cudaDataType_t AsCudaDataType(blas::DataType type) { 43 switch (type) { 44 case blas::DataType::kHalf: 45 return CUDA_R_16F; 46 case blas::DataType::kBF16: 47 return CUDA_R_16BF; 48 case blas::DataType::kFloat: 49 return CUDA_R_32F; 50 case blas::DataType::kDouble: 51 return CUDA_R_64F; 52 case blas::DataType::kInt8: 53 return CUDA_R_8I; 54 case blas::DataType::kInt32: 55 return CUDA_R_32I; 56 case blas::DataType::kComplexFloat: 57 return CUDA_C_32F; 58 case blas::DataType::kComplexDouble: 59 return CUDA_C_64F; 60 default: 61 LOG(FATAL) << "unknown data type"; 62 } 63 } 64 AsCublasComputeType(blas::ComputationType type)65cublasComputeType_t AsCublasComputeType(blas::ComputationType type) { 66 switch (type) { 67 case blas::ComputationType::kF16: 68 return CUBLAS_COMPUTE_16F; 69 case blas::ComputationType::kF32: 70 return CUBLAS_COMPUTE_32F; 71 case blas::ComputationType::kF64: 72 return CUBLAS_COMPUTE_64F; 73 case blas::ComputationType::kI32: 74 return CUBLAS_COMPUTE_32I; 75 case blas::ComputationType::kF16AsF32: 76 return CUBLAS_COMPUTE_32F_FAST_16F; 77 case blas::ComputationType::kBF16AsF32: 78 return CUBLAS_COMPUTE_32F_FAST_16BF; 79 case blas::ComputationType::kTF32AsF32: 80 return CUBLAS_COMPUTE_32F_FAST_TF32; 81 } 82 } 83 AsCublasOperation(blas::Transpose trans)84cublasOperation_t AsCublasOperation(blas::Transpose trans) { 85 switch (trans) { 86 case blas::Transpose::kNoTranspose: 87 return CUBLAS_OP_N; 88 case blas::Transpose::kTranspose: 89 return CUBLAS_OP_T; 90 case blas::Transpose::kConjugateTranspose: 91 return CUBLAS_OP_C; 92 } 93 } 94 95 } // namespace cuda 96 } // namespace stream_executor 97