/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/stream_executor/cuda/cuda_blas_lt.h"

#include <algorithm>
#include <climits>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/synchronization/mutex.h"
#include "third_party/gpus/cuda/include/cublasLt.h"
#include "third_party/gpus/cuda/include/cublas_v2.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/stream_executor/blas.h"
#include "tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.h"
#include "tensorflow/compiler/xla/stream_executor/cuda/cuda_blas_utils.h"
#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_activation.h"
#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_helpers.h"
#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_stream.h"
#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_timer.h"
#include "tensorflow/compiler/xla/stream_executor/scratch_allocator.h"
#include "tensorflow/compiler/xla/stream_executor/stream.h"

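// Helper macros for calling the cublasLt attribute setters/getters and
// converting their cublasStatus_t results into port::Status / port::StatusOr.
// The attribute value is passed by address with its static size, so callers
// must supply a value of exactly the type cublasLt expects for `attr`.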
#define SET_ATTR(setter, handle, attr, value) \
  ToStatus(setter(handle, attr, &value, sizeof(decltype(value))), #setter)

#define GET_ATTR(getter, handle, attr, ValueT)                            \
  [&]() -> port::StatusOr<ValueT> {                                       \
    ValueT value;                                                         \
    TF_RETURN_IF_ERROR(ToStatus(                                          \
        getter(handle, attr, &value, sizeof(ValueT), nullptr), #getter)); \
    return std::move(value);                                              \
  }()

namespace stream_executor {
namespace cuda {
namespace {

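// Type-safe overloads of the macros above, one pair per cublasLt handle type
// (matrix layout, matmul descriptor, matmul preference).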
template <typename T>
port::Status SetAttr(cublasLtMatrixLayout_t handle,
                     cublasLtMatrixLayoutAttribute_t attr, T value) {
  return SET_ATTR(cublasLtMatrixLayoutSetAttribute, handle, attr, value);
}

template <typename T>
port::StatusOr<T> GetAttr(cublasLtMatrixLayout_t handle,
                          cublasLtMatrixLayoutAttribute_t attr) {
  return GET_ATTR(cublasLtMatrixLayoutGetAttribute, handle, attr, T);
}

template <typename T>
port::Status SetAttr(cublasLtMatmulDesc_t handle,
                     cublasLtMatmulDescAttributes_t attr, T value) {
  return SET_ATTR(cublasLtMatmulDescSetAttribute, handle, attr, value);
}

template <typename T>
port::StatusOr<T> GetAttr(cublasLtMatmulDesc_t handle,
                          cublasLtMatmulDescAttributes_t attr) {
  return GET_ATTR(cublasLtMatmulDescGetAttribute, handle, attr, T);
}

template <typename T>
port::Status SetAttr(cublasLtMatmulPreference_t handle,
                     cublasLtMatmulPreferenceAttributes_t attr, T value) {
  return SET_ATTR(cublasLtMatmulPreferenceSetAttribute, handle, attr, value);
}

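// Translates StreamExecutor enums into their cublasLt counterparts.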
cublasLtPointerMode_t AsCublasLtPointerMode(BlasLt::PointerMode pointer_mode) {
  switch (pointer_mode) {
    case BlasLt::PointerMode::kHost:
      return CUBLASLT_POINTER_MODE_HOST;
    case BlasLt::PointerMode::kDevice:
      return CUBLASLT_POINTER_MODE_DEVICE;
  }
}

port::StatusOr<cublasLtEpilogue_t> AsCublasLtEpilogue(
    BlasLt::Epilogue epilogue) {
  switch (epilogue) {
    case BlasLt::Epilogue::kDefault:
      return CUBLASLT_EPILOGUE_DEFAULT;
    case BlasLt::Epilogue::kReLU:
      return CUBLASLT_EPILOGUE_RELU;
    case BlasLt::Epilogue::kBias:
      return CUBLASLT_EPILOGUE_BIAS;
    case BlasLt::Epilogue::kBiasThenReLU:
      return CUBLASLT_EPILOGUE_RELU_BIAS;
    case BlasLt::Epilogue::kGeLU:
#if CUDA_VERSION >= 11040
      return CUBLASLT_EPILOGUE_GELU;
#else
      return port::InternalError(absl::StrCat(
          "CUBLASLT_EPILOGUE_GELU epilogue requires cublasLt >= 11.4"));
#endif
    case BlasLt::Epilogue::kBiasThenGeLUApproximate:
#if CUDA_VERSION >= 11040
      return CUBLASLT_EPILOGUE_GELU_BIAS;
#else
      return port::InternalError(absl::StrCat(
          "CUBLASLT_EPILOGUE_GELU_BIAS epilogue requires cublasLt >= 11.4"));
#endif
  }
}

}  // namespace

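// Creates the cublasLt library handle; all later calls share it under mu_.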
port::Status BlasLt::Init() {
  cublasLtHandle_t blas_lt;
  SE_CUBLAS_RETURN_IF_ERROR(cublasLtCreate(&blas_lt));
  absl::MutexLock lock(&mu_);
  blas_lt_.reset(blas_lt);
  return port::Status::OK();
}

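// cublasLt requires an explicit scale type for alpha/beta. For F32 computation
// with a non-complex output type, scaling is done in F32; otherwise the scale
// type matches the output (C) type.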
/*static*/ blas::DataType BlasLt::GetScaleType(
    blas::DataType c_type, blas::ComputationType computation_type) {
  return ((computation_type == blas::ComputationType::kF32) &&
          (c_type != blas::DataType::kComplexFloat))
             ? blas::DataType::kFloat
             : c_type;
}

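// Builds a cublasLt matrix layout. When not specified, the leading dimension
// defaults to the contiguous extent (dense packing) and the batch stride to
// num_rows * num_cols, or 0 when there is only a single batch.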
/*static*/ port::StatusOr<BlasLt::MatrixLayout> BlasLt::MatrixLayout::Create(
    blas::DataType type, size_t num_rows, size_t num_cols,
    BlasLt::MatrixLayout::Order order, size_t batch_size,
    std::optional<int64_t> leading_dim_stride,
    std::optional<int64_t> batch_stride) {
  if (!leading_dim_stride) {
    leading_dim_stride = (order == Order::kRowMajor) ? num_cols : num_rows;
  }

  cublasLtMatrixLayout_t cu_layout;
  SE_CUBLAS_RETURN_IF_ERROR(
      cublasLtMatrixLayoutCreate(&cu_layout, AsCudaDataType(type), num_rows,
                                 num_cols, *leading_dim_stride));
  // Wrap cublas handle immediately, so it is cleaned up if an error occurs.
  BlasLt::MatrixLayout layout(cu_layout);
  TF_RETURN_IF_ERROR(
      SetAttr(cu_layout, CUBLASLT_MATRIX_LAYOUT_ORDER,
              int32_t{(order == Order::kRowMajor) ? CUBLASLT_ORDER_ROW
                                                  : CUBLASLT_ORDER_COL}));
  TF_RETURN_IF_ERROR(SetAttr(cu_layout, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
                             static_cast<int32_t>(batch_size)));

  if (!batch_stride) {
    batch_stride = (batch_size > 1) ? num_rows * num_cols : 0;
  }

  TF_RETURN_IF_ERROR(SetAttr(
      cu_layout, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, *batch_stride));
  return std::move(layout);
}

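// The attribute accessors below query cublasLt directly; ValueOrDie()
// CHECK-fails if the query fails, which should not happen for a valid handle.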
cudaDataType_t BlasLt::MatrixLayout::type() const {
  return static_cast<cudaDataType_t>(
      GetAttr<uint32_t>(handle_.get(), CUBLASLT_MATRIX_LAYOUT_TYPE)
          .ValueOrDie());
}

/*static*/ port::StatusOr<BlasLt::MatmulDesc> BlasLt::MatmulDesc::Create(
    blas::ComputationType compute_type, blas::DataType scale_type,
    blas::Transpose trans_a, blas::Transpose trans_b, BlasLt::Epilogue epilogue,
    BlasLt::PointerMode pointer_mode) {
  cublasLtMatmulDesc_t cu_desc;
  SE_CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescCreate(
      &cu_desc, AsCublasComputeType(compute_type), AsCudaDataType(scale_type)));
  // Wrap cublas handle immediately, so it is cleaned up if an error occurs.
  BlasLt::MatmulDesc desc(cu_desc);
  TF_RETURN_IF_ERROR(SetAttr(cu_desc, CUBLASLT_MATMUL_DESC_POINTER_MODE,
                             AsCublasLtPointerMode(pointer_mode)));
  TF_RETURN_IF_ERROR(SetAttr(cu_desc, CUBLASLT_MATMUL_DESC_TRANSA,
                             AsCublasOperation(trans_a)));
  TF_RETURN_IF_ERROR(SetAttr(cu_desc, CUBLASLT_MATMUL_DESC_TRANSB,
                             AsCublasOperation(trans_b)));
  TF_ASSIGN_OR_RETURN(cublasLtEpilogue_t epi, AsCublasLtEpilogue(epilogue));
  TF_RETURN_IF_ERROR(SetAttr(cu_desc, CUBLASLT_MATMUL_DESC_EPILOGUE, epi));
  return std::move(desc);
}

cublasComputeType_t BlasLt::MatmulDesc::compute_type() const {
  return static_cast<cublasComputeType_t>(
      GetAttr<int32_t>(handle_.get(), CUBLASLT_MATMUL_DESC_COMPUTE_TYPE)
          .ValueOrDie());
}

cudaDataType_t BlasLt::MatmulDesc::scale_type() const {
  return static_cast<cudaDataType_t>(
      GetAttr<int32_t>(handle_.get(), CUBLASLT_MATMUL_DESC_SCALE_TYPE)
          .ValueOrDie());
}

cublasLtPointerMode_t BlasLt::MatmulDesc::pointer_mode() const {
  return static_cast<cublasLtPointerMode_t>(
      GetAttr<int32_t>(handle_.get(), CUBLASLT_MATMUL_DESC_POINTER_MODE)
          .ValueOrDie());
}

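// The preference currently only constrains the maximum workspace size that the
// algorithm heuristic may assume is available.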
/*static*/ port::StatusOr<BlasLt::MatmulPreference>
BlasLt::MatmulPreference::Create(size_t max_workspace_size) {
  cublasLtMatmulPreference_t cu_preference;
  SE_CUBLAS_RETURN_IF_ERROR(cublasLtMatmulPreferenceCreate(&cu_preference));
  // Wrap cublas handle immediately, so it is cleaned up if an error occurs.
  BlasLt::MatmulPreference preference(cu_preference);
  TF_RETURN_IF_ERROR(SetAttr<uint64_t>(
      cu_preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
      max_workspace_size));
  return std::move(preference);
}

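// Asks the cublasLt heuristic for up to `max_algorithm_count` candidate
// algorithms, holding the mutex only around the library call. Candidates whose
// `state` is not CUBLAS_STATUS_SUCCESS are dropped, so fewer algorithms than
// requested may be returned.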
port::StatusOr<std::vector<BlasLt::MatmulAlgorithm>>
BlasLt::GetMatmulAlgorithms(const BlasLt::MatmulPlan& plan,
                            const BlasLt::MatmulPreference& preference,
                            size_t max_algorithm_count) {
  max_algorithm_count = std::min(max_algorithm_count, size_t{INT_MAX});
  std::vector<cublasLtMatmulHeuristicResult_t> results(max_algorithm_count);
  {
    absl::MutexLock lock(&mu_);
    TF_RET_CHECK(blas_lt_ != nullptr);

    gpu::ScopedActivateExecutorContext sac{parent_};

    int found_algorithm_count = 0;
    SE_CUBLAS_RETURN_IF_ERROR(cublasLtMatmulAlgoGetHeuristic(
        blas_lt_.get(), plan.op_desc.get(), plan.a_desc.get(),
        plan.b_desc.get(), plan.c_desc.get(), plan.d_desc.get(),
        preference.get(), max_algorithm_count, results.data(),
        &found_algorithm_count));
    results.resize(found_algorithm_count);
  }

  std::vector<BlasLt::MatmulAlgorithm> algorithms;
  algorithms.reserve(results.size());
  for (const cublasLtMatmulHeuristicResult_t& result : results) {
    if (result.state == CUBLAS_STATUS_SUCCESS) {  // Skip failed algos.
      algorithms.push_back({result.algo, result.workspaceSize});
    }
  }
  return std::move(algorithms);
}

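// Runs the matmul described by `plan` using the previously selected
// `algorithm`. If profiling was requested, a GPU timer brackets the launch;
// workspace comes from `scratch_allocator`, and the optional bias pointer is
// attached to the shared plan while holding the mutex.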
port::Status BlasLt::DoMatmul(Stream* stream, const BlasLt::MatmulPlan& plan,
                              const void* alpha, DeviceMemoryBase a,
                              DeviceMemoryBase b, const void* beta,
                              DeviceMemoryBase c, DeviceMemoryBase d,
                              const BlasLt::MatmulAlgorithm& algorithm,
                              ScratchAllocator& scratch_allocator,
                              DeviceMemoryBase bias,
                              blas::ProfileResult* profile_result) {
  std::unique_ptr<gpu::GpuTimer, gpu::GpuTimerDeleter> timer;
  if (profile_result != nullptr) {
    timer.reset(new gpu::GpuTimer(parent_));
    TF_RET_CHECK(timer->Init());
    TF_RET_CHECK(timer->Start(gpu::AsGpuStream(stream)));
  }

  void* workspace = nullptr;
  if (algorithm.workspace_size > 0) {
    TF_ASSIGN_OR_RETURN(
        DeviceMemory<uint8_t> alloc,
        scratch_allocator.AllocateBytes(algorithm.workspace_size));
    workspace = gpu::GpuMemoryMutable(&alloc);
  }

  {
    absl::MutexLock lock(&mu_);
    TF_RET_CHECK(blas_lt_ != nullptr);
    // We must set the bias pointer while holding the mutex, to avoid a
    // potential race condition from multiple threads sharing the same plan.
    if (bias != nullptr) {
      TF_RETURN_IF_ERROR(SetAttr(plan.op_desc.get(),
                                 CUBLASLT_MATMUL_DESC_BIAS_POINTER,
                                 bias.opaque()));
    }

    gpu::ScopedActivateExecutorContext sac{parent_};

    SE_CUBLAS_RETURN_IF_ERROR(cublasLtMatmul(
        blas_lt_.get(), plan.op_desc.get(), alpha, a.opaque(),
        plan.a_desc.get(), b.opaque(), plan.b_desc.get(), beta, c.opaque(),
        plan.c_desc.get(), d.opaque(), plan.d_desc.get(), &algorithm.algo,
        workspace, algorithm.workspace_size, gpu::AsGpuStreamValue(stream)));
  }

  if (timer) {
    TF_RET_CHECK(timer->Stop(gpu::AsGpuStream(stream)));
    profile_result->set_is_valid(true);
    profile_result->set_elapsed_time_in_ms(timer->GetElapsedMilliseconds());
  }
  return port::Status::OK();
}

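// Returns the BlasLt instance owned by the stream's BLAS support, or nullptr
// if the stream's BLAS implementation is not cuBLAS.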
BlasLt* GetBlasLt(Stream* stream) {
  CUDABlas* blas = dynamic_cast<CUDABlas*>(stream->parent()->AsBlas());
  return (blas != nullptr) ? &blas->blas_lt() : nullptr;
}

}  // namespace cuda
}  // namespace stream_executor