/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

16 #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_blas_lt.h"
17 
18 #include <algorithm>
19 #include <climits>
20 #include <memory>
21 #include <optional>
22 #include <string>
23 #include <utility>
24 #include <vector>
25 
26 #include "third_party/gpus/cuda/include/cublasLt.h"
27 #include "third_party/gpus/cuda/include/cublas_v2.h"
28 #include "tensorflow/compiler/xla/status_macros.h"
29 #include "tensorflow/compiler/xla/stream_executor/blas.h"
30 #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.h"
31 #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_blas_utils.h"
32 #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_activation.h"
33 #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_helpers.h"
34 #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_stream.h"
35 #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_timer.h"
36 #include "tensorflow/compiler/xla/stream_executor/scratch_allocator.h"
37 #include "tensorflow/compiler/xla/stream_executor/stream.h"
38 
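// Helpers for setting and getting attributes on cuBLASLt handles: SET_ATTR
// wraps a cublasLt*SetAttribute call and converts its cublasStatus_t result
// into a port::Status; GET_ATTR reads an attribute of type ValueT and returns
// it as a port::StatusOr<ValueT>.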
#define SET_ATTR(setter, handle, attr, value) \
  ToStatus(setter(handle, attr, &value, sizeof(decltype(value))), #setter)

#define GET_ATTR(getter, handle, attr, ValueT)                            \
  [&]() -> port::StatusOr<ValueT> {                                       \
    ValueT value;                                                         \
    TF_RETURN_IF_ERROR(ToStatus(                                          \
        getter(handle, attr, &value, sizeof(ValueT), nullptr), #getter)); \
    return std::move(value);                                              \
  }()

namespace stream_executor {
namespace cuda {
namespace {

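// Type-safe wrappers around the SET_ATTR/GET_ATTR macros, one overload per
// cuBLASLt handle type. Taking the value as a typed argument lets the macros
// derive the correct attribute size via sizeof, avoiding size mismatches at
// the C API boundary.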
template <typename T>
port::Status SetAttr(cublasLtMatrixLayout_t handle,
                     cublasLtMatrixLayoutAttribute_t attr, T value) {
  return SET_ATTR(cublasLtMatrixLayoutSetAttribute, handle, attr, value);
}

template <typename T>
port::StatusOr<T> GetAttr(cublasLtMatrixLayout_t handle,
                          cublasLtMatrixLayoutAttribute_t attr) {
  return GET_ATTR(cublasLtMatrixLayoutGetAttribute, handle, attr, T);
}

template <typename T>
port::Status SetAttr(cublasLtMatmulDesc_t handle,
                     cublasLtMatmulDescAttributes_t attr, T value) {
  return SET_ATTR(cublasLtMatmulDescSetAttribute, handle, attr, value);
}

template <typename T>
port::StatusOr<T> GetAttr(cublasLtMatmulDesc_t handle,
                          cublasLtMatmulDescAttributes_t attr) {
  return GET_ATTR(cublasLtMatmulDescGetAttribute, handle, attr, T);
}

template <typename T>
port::Status SetAttr(cublasLtMatmulPreference_t handle,
                     cublasLtMatmulPreferenceAttributes_t attr, T value) {
  return SET_ATTR(cublasLtMatmulPreferenceSetAttribute, handle, attr, value);
}

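// Converts StreamExecutor enums to their cuBLASLt counterparts. The epilogue
// conversion can fail: the GELU epilogues only exist in cuBLASLt 11.4 and
// later, so older toolkits get an error status instead.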
cublasLtPointerMode_t AsCublasLtPointerMode(BlasLt::PointerMode pointer_mode) {
  switch (pointer_mode) {
    case BlasLt::PointerMode::kHost:
      return CUBLASLT_POINTER_MODE_HOST;
    case BlasLt::PointerMode::kDevice:
      return CUBLASLT_POINTER_MODE_DEVICE;
  }
}

port::StatusOr<cublasLtEpilogue_t> AsCublasLtEpilogue(
    BlasLt::Epilogue epilogue) {
  switch (epilogue) {
    case BlasLt::Epilogue::kDefault:
      return CUBLASLT_EPILOGUE_DEFAULT;
    case BlasLt::Epilogue::kReLU:
      return CUBLASLT_EPILOGUE_RELU;
    case BlasLt::Epilogue::kBias:
      return CUBLASLT_EPILOGUE_BIAS;
    case BlasLt::Epilogue::kBiasThenReLU:
      return CUBLASLT_EPILOGUE_RELU_BIAS;
    case BlasLt::Epilogue::kGeLU:
#if CUDA_VERSION >= 11040
      return CUBLASLT_EPILOGUE_GELU;
#else
      return port::InternalError(
          "CUBLASLT_EPILOGUE_GELU epilogue requires cublasLt >= 11.4");
#endif
    case BlasLt::Epilogue::kBiasThenGeLUApproximate:
#if CUDA_VERSION >= 11040
      return CUBLASLT_EPILOGUE_GELU_BIAS;
#else
      return port::InternalError(
          "CUBLASLT_EPILOGUE_GELU_BIAS epilogue requires cublasLt >= 11.4");
#endif
  }
}

}  // namespace

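// Creates the cuBLASLt library handle. The handle is stored behind mu_ so
// that later calls (heuristic queries, matmul launches) can share it safely
// across threads.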
port::Status BlasLt::Init() {
  cublasLtHandle_t blas_lt;
  SE_CUBLAS_RETURN_IF_ERROR(cublasLtCreate(&blas_lt));
  absl::MutexLock lock(&mu_);
  blas_lt_.reset(blas_lt);
  return port::Status::OK();
}

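// When computing in F32 with a non-complex output type (e.g. F16 or BF16
// C/D matrices), alpha and beta must be F32; otherwise the scale type simply
// matches the output type.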
/*static*/ blas::DataType BlasLt::GetScaleType(
    blas::DataType c_type, blas::ComputationType computation_type) {
  return ((computation_type == blas::ComputationType::kF32) &&
          (c_type != blas::DataType::kComplexFloat))
             ? blas::DataType::kFloat
             : c_type;
}

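// Builds a cublasLtMatrixLayout_t describing one operand. If the caller does
// not supply strides, the leading dimension defaults to the extent of a row
// (row-major) or a column (column-major), and the batch stride defaults to
// one dense matrix per batch entry (0 when unbatched).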
/*static*/ port::StatusOr<BlasLt::MatrixLayout> BlasLt::MatrixLayout::Create(
    blas::DataType type, size_t num_rows, size_t num_cols,
    BlasLt::MatrixLayout::Order order, size_t batch_size,
    std::optional<int64_t> leading_dim_stride,
    std::optional<int64_t> batch_stride) {
  if (!leading_dim_stride) {
    leading_dim_stride = (order == Order::kRowMajor) ? num_cols : num_rows;
  }

  cublasLtMatrixLayout_t cu_layout;
  SE_CUBLAS_RETURN_IF_ERROR(
      cublasLtMatrixLayoutCreate(&cu_layout, AsCudaDataType(type), num_rows,
                                 num_cols, *leading_dim_stride));
  // Wrap cublas handle immediately, so it is cleaned up if an error occurs.
  BlasLt::MatrixLayout layout(cu_layout);
  TF_RETURN_IF_ERROR(
      SetAttr(cu_layout, CUBLASLT_MATRIX_LAYOUT_ORDER,
              int32_t{(order == Order::kRowMajor) ? CUBLASLT_ORDER_ROW
                                                  : CUBLASLT_ORDER_COL}));
  TF_RETURN_IF_ERROR(SetAttr(cu_layout, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
                             static_cast<int32_t>(batch_size)));

  if (!batch_stride) {
    batch_stride = (batch_size > 1) ? num_rows * num_cols : 0;
  }

  TF_RETURN_IF_ERROR(SetAttr(
      cu_layout, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, *batch_stride));
  return std::move(layout);
}

cudaDataType_t BlasLt::MatrixLayout::type() const {
  return static_cast<cudaDataType_t>(
      GetAttr<uint32_t>(handle_.get(), CUBLASLT_MATRIX_LAYOUT_TYPE)
          .ValueOrDie());
}

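// Builds a cublasLtMatmulDesc_t carrying everything about the matmul other
// than the operand layouts: compute/scale types, transposes, pointer mode,
// and the fused epilogue (bias and/or activation).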
/*static*/ port::StatusOr<BlasLt::MatmulDesc> BlasLt::MatmulDesc::Create(
    blas::ComputationType compute_type, blas::DataType scale_type,
    blas::Transpose trans_a, blas::Transpose trans_b, BlasLt::Epilogue epilogue,
    BlasLt::PointerMode pointer_mode) {
  cublasLtMatmulDesc_t cu_desc;
  SE_CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescCreate(
      &cu_desc, AsCublasComputeType(compute_type), AsCudaDataType(scale_type)));
  // Wrap cublas handle immediately, so it is cleaned up if an error occurs.
  BlasLt::MatmulDesc desc(cu_desc);
  TF_RETURN_IF_ERROR(SetAttr(cu_desc, CUBLASLT_MATMUL_DESC_POINTER_MODE,
                             AsCublasLtPointerMode(pointer_mode)));
  TF_RETURN_IF_ERROR(SetAttr(cu_desc, CUBLASLT_MATMUL_DESC_TRANSA,
                             AsCublasOperation(trans_a)));
  TF_RETURN_IF_ERROR(SetAttr(cu_desc, CUBLASLT_MATMUL_DESC_TRANSB,
                             AsCublasOperation(trans_b)));
  TF_ASSIGN_OR_RETURN(cublasLtEpilogue_t epi, AsCublasLtEpilogue(epilogue));
  TF_RETURN_IF_ERROR(SetAttr(cu_desc, CUBLASLT_MATMUL_DESC_EPILOGUE, epi));
  return std::move(desc);
}

cublasComputeType_t BlasLt::MatmulDesc::compute_type() const {
  return static_cast<cublasComputeType_t>(
      GetAttr<int32_t>(handle_.get(), CUBLASLT_MATMUL_DESC_COMPUTE_TYPE)
          .ValueOrDie());
}

cudaDataType_t BlasLt::MatmulDesc::scale_type() const {
  return static_cast<cudaDataType_t>(
      GetAttr<int32_t>(handle_.get(), CUBLASLT_MATMUL_DESC_SCALE_TYPE)
          .ValueOrDie());
}

cublasLtPointerMode_t BlasLt::MatmulDesc::pointer_mode() const {
  return static_cast<cublasLtPointerMode_t>(
      GetAttr<int32_t>(handle_.get(), CUBLASLT_MATMUL_DESC_POINTER_MODE)
          .ValueOrDie());
}

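// A preference object currently only constrains scratch space: the heuristic
// below will not return algorithms whose workspace requirement exceeds
// max_workspace_size bytes.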
/*static*/ port::StatusOr<BlasLt::MatmulPreference>
BlasLt::MatmulPreference::Create(size_t max_workspace_size) {
  cublasLtMatmulPreference_t cu_preference;
  SE_CUBLAS_RETURN_IF_ERROR(cublasLtMatmulPreferenceCreate(&cu_preference));
  // Wrap cublas handle immediately, so it is cleaned up if an error occurs.
  BlasLt::MatmulPreference preference(cu_preference);
  TF_RETURN_IF_ERROR(SetAttr<uint64_t>(cu_preference,
                                       CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
                                       max_workspace_size));
  return std::move(preference);
}

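// Queries the cuBLASLt heuristic for algorithms compatible with the plan and
// preference, returning at most max_algorithm_count of them. Heuristic
// entries whose state is not CUBLAS_STATUS_SUCCESS are silently dropped
// rather than surfaced as errors.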
port::StatusOr<std::vector<BlasLt::MatmulAlgorithm>>
BlasLt::GetMatmulAlgorithms(const BlasLt::MatmulPlan& plan,
                            const BlasLt::MatmulPreference& preference,
                            size_t max_algorithm_count) {
  max_algorithm_count = std::min(max_algorithm_count, size_t{INT_MAX});
  std::vector<cublasLtMatmulHeuristicResult_t> results(max_algorithm_count);
  {
    absl::MutexLock lock(&mu_);
    TF_RET_CHECK(blas_lt_ != nullptr);

    gpu::ScopedActivateExecutorContext sac{parent_};

    int found_algorithm_count = 0;
    SE_CUBLAS_RETURN_IF_ERROR(cublasLtMatmulAlgoGetHeuristic(
        blas_lt_.get(), plan.op_desc.get(), plan.a_desc.get(),
        plan.b_desc.get(), plan.c_desc.get(), plan.d_desc.get(),
        preference.get(), max_algorithm_count, results.data(),
        &found_algorithm_count));
    results.resize(found_algorithm_count);
  }

  std::vector<BlasLt::MatmulAlgorithm> algorithms;
  algorithms.reserve(results.size());
  for (const cublasLtMatmulHeuristicResult_t& result : results) {
    if (result.state == CUBLAS_STATUS_SUCCESS) {  // Skip failed algos.
      algorithms.push_back({result.algo, result.workspaceSize});
    }
  }
  return std::move(algorithms);
}

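// Executes the matmul described by `plan` with a previously selected
// algorithm. When profiling is requested, the launch is bracketed with GPU
// timers; the optional bias pointer is attached to the shared plan while
// holding mu_, to avoid racing with other threads using the same plan.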
port::Status BlasLt::DoMatmul(Stream* stream, const BlasLt::MatmulPlan& plan,
                              const void* alpha, DeviceMemoryBase a,
                              DeviceMemoryBase b, const void* beta,
                              DeviceMemoryBase c, DeviceMemoryBase d,
                              const BlasLt::MatmulAlgorithm& algorithm,
                              ScratchAllocator& scratch_allocator,
                              DeviceMemoryBase bias,
                              blas::ProfileResult* profile_result) {
  std::unique_ptr<gpu::GpuTimer, gpu::GpuTimerDeleter> timer;
  if (profile_result != nullptr) {
    timer.reset(new gpu::GpuTimer(parent_));
    TF_RET_CHECK(timer->Init());
    TF_RET_CHECK(timer->Start(gpu::AsGpuStream(stream)));
  }

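  // Allocate scratch memory for the algorithm's workspace, if it needs any.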
  void* workspace = nullptr;
  if (algorithm.workspace_size > 0) {
    TF_ASSIGN_OR_RETURN(
        DeviceMemory<uint8_t> alloc,
        scratch_allocator.AllocateBytes(algorithm.workspace_size));
    workspace = gpu::GpuMemoryMutable(&alloc);
  }

  {
    absl::MutexLock lock(&mu_);
    TF_RET_CHECK(blas_lt_ != nullptr);
    // We must set the bias pointer while holding the mutex, to avoid a
    // potential race condition from multiple threads sharing the same plan.
    if (bias != nullptr) {
      TF_RETURN_IF_ERROR(SetAttr(plan.op_desc.get(),
                                 CUBLASLT_MATMUL_DESC_BIAS_POINTER,
                                 bias.opaque()));
    }

    gpu::ScopedActivateExecutorContext sac{parent_};

    SE_CUBLAS_RETURN_IF_ERROR(cublasLtMatmul(
        blas_lt_.get(), plan.op_desc.get(), alpha, a.opaque(),
        plan.a_desc.get(), b.opaque(), plan.b_desc.get(), beta, c.opaque(),
        plan.c_desc.get(), d.opaque(), plan.d_desc.get(), &algorithm.algo,
        workspace, algorithm.workspace_size, gpu::AsGpuStreamValue(stream)));
  }

  if (timer) {
    TF_RET_CHECK(timer->Stop(gpu::AsGpuStream(stream)));
    profile_result->set_is_valid(true);
    profile_result->set_elapsed_time_in_ms(timer->GetElapsedMilliseconds());
  }
  return port::Status::OK();
}

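// Returns the BlasLt instance owned by the stream's CUDABlas support, or
// nullptr when the stream's BLAS plugin is not cuBLAS. A minimal usage
// sketch, with `plan` and `preference` assumed to be already built
// (hypothetical variable names, error handling elided):
//
//   BlasLt* blas_lt = GetBlasLt(stream);
//   if (blas_lt != nullptr) {
//     auto algorithms = blas_lt->GetMatmulAlgorithms(
//         plan, preference, /*max_algorithm_count=*/8);
//     // Pick one of *algorithms and pass it to blas_lt->DoMatmul().
//   }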
BlasLt* GetBlasLt(Stream* stream) {
  CUDABlas* blas = dynamic_cast<CUDABlas*>(stream->parent()->AsBlas());
  return (blas != nullptr) ? &blas->blas_lt() : nullptr;
}

}  // namespace cuda
}  // namespace stream_executor