1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // CUDA-specific support for BLAS functionality -- this wraps the cuBLAS library 17 // capabilities, and is only included into CUDA implementation code -- it will 18 // not introduce cuda headers into other code. 19 20 #ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_ 21 #define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_ 22 23 #include "absl/base/thread_annotations.h" 24 #include "absl/synchronization/mutex.h" 25 #include "third_party/gpus/cuda/include/cublas_v2.h" 26 #include "tensorflow/compiler/xla/stream_executor/blas.h" 27 #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_blas_lt.h" 28 #include "tensorflow/compiler/xla/stream_executor/platform/port.h" 29 #include "tensorflow/compiler/xla/stream_executor/plugin_registry.h" 30 31 namespace stream_executor { 32 33 class Stream; 34 35 namespace gpu { 36 class GpuExecutor; 37 } // namespace gpu 38 39 namespace cuda { 40 41 // Opaque and unique identifier for the cuBLAS plugin. 42 extern const PluginId kCuBlasPlugin; 43 44 template <typename T> 45 using DeviceMemorySlice = port::ArraySlice<DeviceMemory<T> *>; // non-absl ok 46 47 // BLAS plugin for CUDA platform via cuBLAS library. 48 // 49 // This satisfies the platform-agnostic BlasSupport interface. 50 // 51 // Note that the cuBLAS handle that this encapsulates is implicitly tied to the 52 // context (and, as a result, the device) that the parent GpuExecutor is tied 53 // to. This simply happens as an artifact of creating the cuBLAS handle when a 54 // CUDA context is active. 55 // 56 // Thread-safe post-initialization. 57 class CUDABlas : public blas::BlasSupport { 58 public: 59 explicit CUDABlas(gpu::GpuExecutor *parent); 60 61 // Allocates a cuBLAS handle. 62 bool Init(); 63 64 // Releases the cuBLAS handle, if present. 65 ~CUDABlas() override; 66 67 TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES 68 blas_lt()69 BlasLt &blas_lt() { return blas_lt_; } 70 71 private: 72 // Tells cuBLAS to enqueue the BLAS operation onto a particular Stream. 73 // 74 // cuBLAS is stateful, and only be associated with one stream (in order to 75 // enqueue dispatch) at a given time. As a result, this generally must be 76 // invoked before calling into cuBLAS. 77 bool SetStream(Stream *stream) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); 78 79 // Returns the underlying CUDA stream. 80 cudaStream_t CUDAStream(Stream *stream); 81 82 // A helper function that calls the real cuBLAS function together with error 83 // handling. 84 // 85 // cublas_func: cuBLAS function pointer. 86 // cublas_name: cuBLAS function name. 87 // stream: Stream to enqueue the BLAS operation onto. 88 // pointer_mode_host: Indicate if the pointer to a scalar value is from host 89 // (true) or device (false). 90 // args: Arguments of cuBLAS function. 91 template <typename FuncT, typename... Args> 92 port::Status DoBlasInternalImpl(FuncT cublas_func, Stream *stream, 93 bool pointer_mode_host, 94 cublasMath_t math_type, Args... args); 95 96 // Convenience functions that call DoBlasInternalImpl with err_on_failure=true 97 // and math_type=CUBLAS_DEFAULT_MATH. 98 template <typename FuncT, typename... Args> DoBlasInternal(FuncT cublas_func,Stream * stream,bool pointer_mode_host,Args...args)99 bool DoBlasInternal(FuncT cublas_func, Stream *stream, bool pointer_mode_host, 100 Args... args) { 101 return DoBlasInternalImpl(cublas_func, stream, pointer_mode_host, 102 CUBLAS_DEFAULT_MATH, args...) 103 .ok(); 104 } 105 106 // A helper function to implement DoBlasGemmBatched interfaces for generic 107 // types. 108 template <typename T, typename Scalar, typename FuncT> 109 port::Status DoBlasGemmBatchedInternal( 110 FuncT cublas_func, Stream *stream, blas::Transpose transa, 111 blas::Transpose transb, uint64_t m, uint64 n, uint64 k, Scalar alpha, 112 const DeviceMemorySlice<T> &a_array, int lda, 113 const DeviceMemorySlice<T> &b_array, int ldb, Scalar beta, 114 const DeviceMemorySlice<T> &c_array, int ldc, int batch_count, 115 ScratchAllocator *scratch_allocator); 116 117 // Helper function for implementing DoBlasGemmWithProfiling. 118 template <typename T, typename ParamType> 119 bool DoBlasGemmWithProfilingImpl(Stream *stream, blas::Transpose transa, 120 blas::Transpose transb, uint64_t m, 121 uint64_t n, uint64 k, const ParamType &alpha, 122 const DeviceMemory<T> &a, int lda, 123 const DeviceMemory<T> &b, int ldb, 124 const ParamType &beta, DeviceMemory<T> *c, 125 int ldc, 126 blas::ProfileResult *output_profile_result); 127 128 // Helper function for implementing DoBlasGemvWithProfiling. 129 template <typename T> 130 bool DoBlasGemvWithProfilingImpl(Stream *stream, blas::Transpose trans, 131 uint64_t m, uint64 n, const T &alpha, 132 const DeviceMemory<T> &a, int lda, 133 const DeviceMemory<T> &x, int incx, 134 const T &beta, DeviceMemory<T> *y, int incy, 135 blas::ProfileResult *output_profile_result); 136 137 // Guards the cuBLAS handle for this device. 138 absl::Mutex mu_; 139 140 // GpuExecutor which instantiated this CUDABlas. 141 // Immutable post-initialization. 142 gpu::GpuExecutor *parent_; 143 144 // cuBLAS library handle on the device. 145 cublasHandle_t blas_ ABSL_GUARDED_BY(mu_); 146 147 BlasLt blas_lt_; 148 149 SE_DISALLOW_COPY_AND_ASSIGN(CUDABlas); 150 }; 151 152 } // namespace cuda 153 } // namespace stream_executor 154 155 #endif // TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_ 156