xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_blas_lt.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_CUDA_CUDA_BLAS_LT_H_
17 #define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_CUDA_CUDA_BLAS_LT_H_
18 
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <type_traits>
#include <vector>

#include "absl/base/thread_annotations.h"
#include "absl/synchronization/mutex.h"
#include "third_party/gpus/cuda/include/cublasLt.h"
#include "third_party/gpus/cuda/include/cublas_v2.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "tensorflow/compiler/xla/stream_executor/blas.h"
#include "tensorflow/compiler/xla/stream_executor/cuda/cuda_blas_utils.h"
#include "tensorflow/compiler/xla/stream_executor/host_or_device_scalar.h"
#include "tensorflow/compiler/xla/stream_executor/lib/status.h"
32 
33 namespace stream_executor {
namespace gpu {
// Forward declaration only: this header stores and passes a
// `gpu::GpuExecutor*` (see `cuda::BlasLt`) and never needs the full type.
class GpuExecutor;
}  // namespace gpu
37 
38 namespace cuda {
39 
40 class BlasLt {
41   template <typename T>
42   using Owned =
43       std::unique_ptr<std::remove_pointer_t<T>, cublasStatus_t (*)(T)>;
44 
45  public:
46   class MatrixLayout {
47    public:
48     enum class Order { kRowMajor, kColumnMajor };
49 
50     // If `leading_dim_stride` is not specified, it defaults to:
51     //  - `num_cols` if `order == kRowMajor`,
52     //  - `num_rows` if `order == kColumnMajor`.
53     // If `batch_stride` is not specified, it defaults to `num_rows * num_cols`
54     // if `batch_size > 1`, otherwise `0`.
55     static port::StatusOr<MatrixLayout> Create(
56         blas::DataType type, size_t num_rows, size_t num_cols, Order order,
57         size_t batch_size = 1,
58         std::optional<int64_t> leading_dim_stride = std::nullopt,
59         std::optional<int64_t> batch_stride = std::nullopt);
60 
61     cudaDataType_t type() const;
62 
get()63     cublasLtMatrixLayout_t get() const { return handle_.get(); }
64 
65    private:
MatrixLayout(cublasLtMatrixLayout_t handle)66     explicit MatrixLayout(cublasLtMatrixLayout_t handle)
67         : handle_(handle, cublasLtMatrixLayoutDestroy) {}
68 
69     Owned<cublasLtMatrixLayout_t> handle_;
70   };
71 
72   enum class Epilogue {
73     kDefault = 1,                   // No special postprocessing
74     kReLU = 2,                      // Apply point-wise ReLU function
75     kBias = 4,                      // Add broadcasted bias vector
76     kBiasThenReLU = kBias | kReLU,  // Apply bias and then ReLU transform
77     kGeLU = 32,  // Apply GELU point-wise transform to the results
78     kBiasThenGeLUApproximate =
79         kBias | kGeLU,  // Apply bias and then GeLU Tanh transform
80   };
81 
82   // Describes the location of pointers for the scaling factors alpha and beta.
83   enum class PointerMode {
84     kHost,
85     kDevice,
86   };
87 
88   class MatmulDesc {
89    public:
90     static port::StatusOr<MatmulDesc> Create(
91         blas::ComputationType compute_type, blas::DataType scale_type,
92         blas::Transpose trans_a = blas::Transpose::kNoTranspose,
93         blas::Transpose trans_b = blas::Transpose::kNoTranspose,
94         Epilogue epilogue = Epilogue::kDefault,
95         PointerMode pointer_mode = PointerMode::kHost);
96 
97     cublasComputeType_t compute_type() const;
98     cudaDataType_t scale_type() const;
99     cublasLtPointerMode_t pointer_mode() const;
100 
get()101     cublasLtMatmulDesc_t get() const { return handle_.get(); }
102 
103    private:
MatmulDesc(cublasLtMatmulDesc_t handle)104     explicit MatmulDesc(cublasLtMatmulDesc_t handle)
105         : handle_(handle, cublasLtMatmulDescDestroy) {}
106 
107     Owned<cublasLtMatmulDesc_t> handle_;
108   };
109 
110   // TODO(cjfj): Add consistency checks for types, shapes, etc.?
111   struct MatmulPlan {
112     MatmulDesc op_desc;
113     MatrixLayout a_desc;
114     MatrixLayout b_desc;
115     MatrixLayout c_desc;
116     MatrixLayout d_desc;
117   };
118 
119   class MatmulPreference {
120    public:
121     static port::StatusOr<MatmulPreference> Create(size_t max_workspace_size);
122 
get()123     cublasLtMatmulPreference_t get() const { return handle_.get(); }
124 
125    private:
MatmulPreference(cublasLtMatmulPreference_t handle)126     explicit MatmulPreference(cublasLtMatmulPreference_t handle)
127         : handle_(handle, cublasLtMatmulPreferenceDestroy) {}
128 
129     Owned<cublasLtMatmulPreference_t> handle_;
130   };
131 
132   struct MatmulAlgorithm {
133     cublasLtMatmulAlgo_t algo;
134     size_t workspace_size;
135   };
136 
BlasLt(gpu::GpuExecutor * parent)137   explicit BlasLt(gpu::GpuExecutor* parent)
138       : parent_(parent), blas_lt_(nullptr, cublasLtDestroy) {}
139 
140   port::Status Init();
141 
142   // Returns the type for the alpha and beta scalars.
143   static blas::DataType GetScaleType(blas::DataType c_type,
144                                      blas::ComputationType computation_type);
145 
146   // Returns a list of supported algorithms for DoMatmul. The algorithms are
147   // returned in the order of increasing estimated compute time according to an
148   // internal heuristic.
149   port::StatusOr<std::vector<MatmulAlgorithm>> GetMatmulAlgorithms(
150       const MatmulPlan& plan, const MatmulPreference& preference,
151       size_t max_algorithm_count = 128);
152 
153   template <typename AB, typename CD, typename Scale>
154   port::Status DoMatmul(Stream* stream, const MatmulPlan& plan,
155                         const HostOrDeviceScalar<Scale>& alpha,
156                         const DeviceMemory<AB>& a, const DeviceMemory<AB>& b,
157                         const HostOrDeviceScalar<Scale>& beta,
158                         const DeviceMemory<CD>& c, DeviceMemory<CD>& d,
159                         const MatmulAlgorithm& algorithm,
160                         ScratchAllocator& scratch_allocator,
161                         const DeviceMemory<CD>& bias = {},
162                         blas::ProfileResult* profile_result = nullptr) {
163     if (AsCudaDataType(blas::ToDataType<Scale>::value) !=
164         plan.op_desc.scale_type()) {
165       return port::InvalidArgumentError("mismatched scale types");
166     }
167 
168     bool expect_scale_factor_on_device =
169         (plan.op_desc.pointer_mode() == CUBLASLT_POINTER_MODE_DEVICE);
170 
171     if (alpha.on_device() != expect_scale_factor_on_device) {
172       return port::InvalidArgumentError("wrong location for alpha");
173     }
174 
175     if (beta.on_device() != expect_scale_factor_on_device) {
176       return port::InvalidArgumentError("wrong location for beta");
177     }
178 
179     if (AsCudaDataType(blas::ToDataType<AB>::value) != plan.a_desc.type()) {
180       return port::InvalidArgumentError("mismatched A matrix types");
181     }
182 
183     if (AsCudaDataType(blas::ToDataType<AB>::value) != plan.b_desc.type()) {
184       return port::InvalidArgumentError("mismatched B matrix types");
185     }
186 
187     if (AsCudaDataType(blas::ToDataType<CD>::value) != plan.c_desc.type()) {
188       return port::InvalidArgumentError("mismatched C matrix types");
189     }
190 
191     if (AsCudaDataType(blas::ToDataType<CD>::value) != plan.d_desc.type()) {
192       return port::InvalidArgumentError("mismatched D matrix types");
193     }
194 
195     return DoMatmul(stream, plan, alpha.opaque(), a, b, beta.opaque(), c, d,
196                     algorithm, scratch_allocator, bias, profile_result);
197   }
198 
199  private:
200   port::Status DoMatmul(Stream* stream, const MatmulPlan& plan,
201                         const void* alpha, DeviceMemoryBase a,
202                         DeviceMemoryBase b, const void* beta,
203                         DeviceMemoryBase c, DeviceMemoryBase d,
204                         const MatmulAlgorithm& algorithm,
205                         ScratchAllocator& scratch_allocator,
206                         DeviceMemoryBase bias,
207                         blas::ProfileResult* profile_result);
208 
209   gpu::GpuExecutor* parent_;
210 
211   absl::Mutex mu_;
212   Owned<cublasLtHandle_t> blas_lt_ ABSL_GUARDED_BY(mu_);
213 };
214 
215 // Returns `BlasLt` implementation for a stream if available, or `nullptr`.
216 BlasLt* GetBlasLt(Stream* stream);
217 
218 }  // namespace cuda
219 }  // namespace stream_executor
220 
221 #endif  // TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_CUDA_CUDA_BLAS_LT_H_
222