/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// CUDA-specific support for BLAS functionality -- this wraps the cuBLAS library
// capabilities, and is only included into CUDA implementation code -- it will
// not introduce cuda headers into other code.

#ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_
#define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_

#include "absl/base/thread_annotations.h"
#include "absl/synchronization/mutex.h"
#include "third_party/gpus/cuda/include/cublas_v2.h"
#include "tensorflow/compiler/xla/stream_executor/blas.h"
#include "tensorflow/compiler/xla/stream_executor/cuda/cuda_blas_lt.h"
#include "tensorflow/compiler/xla/stream_executor/platform/port.h"
#include "tensorflow/compiler/xla/stream_executor/plugin_registry.h"

namespace stream_executor {

class Stream;

namespace gpu {
class GpuExecutor;
}  // namespace gpu

namespace cuda {

// Opaque and unique identifier for the cuBLAS plugin.
extern const PluginId kCuBlasPlugin;

template <typename T>
using DeviceMemorySlice = port::ArraySlice<DeviceMemory<T> *>;  // non-absl ok

// BLAS plugin for CUDA platform via cuBLAS library.
//
// This satisfies the platform-agnostic BlasSupport interface.
//
// Note that the cuBLAS handle that this encapsulates is implicitly tied to the
// context (and, as a result, the device) that the parent GpuExecutor is tied
// to. This simply happens as an artifact of creating the cuBLAS handle when a
// CUDA context is active.
//
// Thread-safe post-initialization.
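//
// Illustrative usage (a sketch, not part of this header): callers normally
// reach this plugin through the platform-agnostic Stream / BlasSupport
// interface rather than constructing CUDABlas directly, e.g. roughly
//
//   stream.ThenBlasGemm(/* transposes, dimensions, device buffers, ... */);
//
// where the Stream's parent GpuExecutor routes the call to the registered
// BLAS plugin; see stream.h for the exact overloads.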
class CUDABlas : public blas::BlasSupport {
 public:
  explicit CUDABlas(gpu::GpuExecutor *parent);

  // Allocates a cuBLAS handle.
  bool Init();

  // Releases the cuBLAS handle, if present.
  ~CUDABlas() override;

  TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES

  BlasLt &blas_lt() { return blas_lt_; }

 private:
  // Tells cuBLAS to enqueue the BLAS operation onto a particular Stream.
  //
  // cuBLAS is stateful, and can only be associated with one stream (in order
  // to enqueue dispatch) at a given time. As a result, this generally must be
  // invoked before calling into cuBLAS.
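  //
  // For example, a cuBLAS-calling helper such as DoBlasInternalImpl is
  // expected to follow roughly this pattern (a sketch only, not the literal
  // implementation):
  //
  //   absl::MutexLock lock(&mu_);
  //   if (!SetStream(stream)) {
  //     /* report failure */
  //   }
  //   cublasStatus_t ret = cublas_func(/* cuBLAS handle and arguments */);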
  bool SetStream(Stream *stream) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Returns the underlying CUDA stream.
  cudaStream_t CUDAStream(Stream *stream);

  // A helper function that calls the real cuBLAS function together with error
  // handling.
  //
  // cublas_func:        cuBLAS function pointer.
  // stream:             Stream to enqueue the BLAS operation onto.
  // pointer_mode_host:  Indicate if the pointer to a scalar value is from host
  //                     (true) or device (false).
  // math_type:          cuBLAS math mode to use for the call.
  // args:               Arguments of cuBLAS function.
  template <typename FuncT, typename... Args>
  port::Status DoBlasInternalImpl(FuncT cublas_func, Stream *stream,
                                  bool pointer_mode_host,
                                  cublasMath_t math_type, Args... args);

  // Convenience function that calls DoBlasInternalImpl with
  // math_type=CUBLAS_DEFAULT_MATH and converts the resulting status to bool.
  template <typename FuncT, typename... Args>
  bool DoBlasInternal(FuncT cublas_func, Stream *stream, bool pointer_mode_host,
                      Args... args) {
    return DoBlasInternalImpl(cublas_func, stream, pointer_mode_host,
                              CUBLAS_DEFAULT_MATH, args...)
        .ok();
  }
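
  // For illustration only, a typical DoBlasXxx override in the .cc file can
  // wrap a cuBLAS entry point through this helper roughly as follows (a
  // sketch, assuming the GpuMemoryMutable() device-pointer accessor from
  // gpu_helpers.h; integer-width conversions are elided):
  //
  //   bool CUDABlas::DoBlasScal(Stream *stream, uint64_t elem_count,
  //                             float alpha, DeviceMemory<float> *x,
  //                             int incx) {
  //     return DoBlasInternal(cublasSscal, stream, /*pointer_mode_host=*/true,
  //                           elem_count, &alpha, GpuMemoryMutable(x), incx);
  //   }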

  // A helper function to implement DoBlasGemmBatched interfaces for generic
  // types.
  template <typename T, typename Scalar, typename FuncT>
  port::Status DoBlasGemmBatchedInternal(
      FuncT cublas_func, Stream *stream, blas::Transpose transa,
      blas::Transpose transb, uint64_t m, uint64 n, uint64 k, Scalar alpha,
      const DeviceMemorySlice<T> &a_array, int lda,
      const DeviceMemorySlice<T> &b_array, int ldb, Scalar beta,
      const DeviceMemorySlice<T> &c_array, int ldc, int batch_count,
      ScratchAllocator *scratch_allocator);
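
  // Note: the batched cuBLAS GEMM entry points take arrays of device pointers
  // that must themselves reside in device memory, so the scratch_allocator
  // above is presumably used to stage those pointer arrays on the device
  // before the batched call is issued (the actual behavior lives in the .cc
  // file).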

  // Helper function for implementing DoBlasGemmWithProfiling.
  template <typename T, typename ParamType>
  bool DoBlasGemmWithProfilingImpl(Stream *stream, blas::Transpose transa,
                                   blas::Transpose transb, uint64_t m,
                                   uint64_t n, uint64 k, const ParamType &alpha,
                                   const DeviceMemory<T> &a, int lda,
                                   const DeviceMemory<T> &b, int ldb,
                                   const ParamType &beta, DeviceMemory<T> *c,
                                   int ldc,
                                   blas::ProfileResult *output_profile_result);

  // Helper function for implementing DoBlasGemvWithProfiling.
  template <typename T>
  bool DoBlasGemvWithProfilingImpl(Stream *stream, blas::Transpose trans,
                                   uint64_t m, uint64 n, const T &alpha,
                                   const DeviceMemory<T> &a, int lda,
                                   const DeviceMemory<T> &x, int incx,
                                   const T &beta, DeviceMemory<T> *y, int incy,
                                   blas::ProfileResult *output_profile_result);

  // Guards the cuBLAS handle for this device.
  absl::Mutex mu_;

  // GpuExecutor which instantiated this CUDABlas.
  // Immutable post-initialization.
  gpu::GpuExecutor *parent_;

  // cuBLAS library handle on the device.
  cublasHandle_t blas_ ABSL_GUARDED_BY(mu_);

  BlasLt blas_lt_;

  SE_DISALLOW_COPY_AND_ASSIGN(CUDABlas);
};

}  // namespace cuda
}  // namespace stream_executor

#endif  // TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_