1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_CUDA_CUDA_BLAS_LT_H_ 17 #define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_CUDA_CUDA_BLAS_LT_H_ 18 19 #include <algorithm> 20 #include <memory> 21 #include <optional> 22 #include <string> 23 #include <vector> 24 25 #include "third_party/gpus/cuda/include/cublasLt.h" 26 #include "third_party/gpus/cuda/include/cublas_v2.h" 27 #include "third_party/gpus/cuda/include/cuda.h" 28 #include "tensorflow/compiler/xla/stream_executor/blas.h" 29 #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_blas_utils.h" 30 #include "tensorflow/compiler/xla/stream_executor/host_or_device_scalar.h" 31 #include "tensorflow/compiler/xla/stream_executor/lib/status.h" 32 33 namespace stream_executor { 34 namespace gpu { 35 class GpuExecutor; 36 } // namespace gpu 37 38 namespace cuda { 39 40 class BlasLt { 41 template <typename T> 42 using Owned = 43 std::unique_ptr<std::remove_pointer_t<T>, cublasStatus_t (*)(T)>; 44 45 public: 46 class MatrixLayout { 47 public: 48 enum class Order { kRowMajor, kColumnMajor }; 49 50 // If `leading_dim_stride` is not specified, it defaults to: 51 // - `num_cols` if `order == kRowMajor`, 52 // - `num_rows` if `order == kColumnMajor`. 53 // If `batch_stride` is not specified, it defaults to `num_rows * num_cols` 54 // if `batch_size > 1`, otherwise `0`. 55 static port::StatusOr<MatrixLayout> Create( 56 blas::DataType type, size_t num_rows, size_t num_cols, Order order, 57 size_t batch_size = 1, 58 std::optional<int64_t> leading_dim_stride = std::nullopt, 59 std::optional<int64_t> batch_stride = std::nullopt); 60 61 cudaDataType_t type() const; 62 get()63 cublasLtMatrixLayout_t get() const { return handle_.get(); } 64 65 private: MatrixLayout(cublasLtMatrixLayout_t handle)66 explicit MatrixLayout(cublasLtMatrixLayout_t handle) 67 : handle_(handle, cublasLtMatrixLayoutDestroy) {} 68 69 Owned<cublasLtMatrixLayout_t> handle_; 70 }; 71 72 enum class Epilogue { 73 kDefault = 1, // No special postprocessing 74 kReLU = 2, // Apply point-wise ReLU function 75 kBias = 4, // Add broadcasted bias vector 76 kBiasThenReLU = kBias | kReLU, // Apply bias and then ReLU transform 77 kGeLU = 32, // Apply GELU point-wise transform to the results 78 kBiasThenGeLUApproximate = 79 kBias | kGeLU, // Apply bias and then GeLU Tanh transform 80 }; 81 82 // Describes the location of pointers for the scaling factors alpha and beta. 83 enum class PointerMode { 84 kHost, 85 kDevice, 86 }; 87 88 class MatmulDesc { 89 public: 90 static port::StatusOr<MatmulDesc> Create( 91 blas::ComputationType compute_type, blas::DataType scale_type, 92 blas::Transpose trans_a = blas::Transpose::kNoTranspose, 93 blas::Transpose trans_b = blas::Transpose::kNoTranspose, 94 Epilogue epilogue = Epilogue::kDefault, 95 PointerMode pointer_mode = PointerMode::kHost); 96 97 cublasComputeType_t compute_type() const; 98 cudaDataType_t scale_type() const; 99 cublasLtPointerMode_t pointer_mode() const; 100 get()101 cublasLtMatmulDesc_t get() const { return handle_.get(); } 102 103 private: MatmulDesc(cublasLtMatmulDesc_t handle)104 explicit MatmulDesc(cublasLtMatmulDesc_t handle) 105 : handle_(handle, cublasLtMatmulDescDestroy) {} 106 107 Owned<cublasLtMatmulDesc_t> handle_; 108 }; 109 110 // TODO(cjfj): Add consistency checks for types, shapes, etc.? 111 struct MatmulPlan { 112 MatmulDesc op_desc; 113 MatrixLayout a_desc; 114 MatrixLayout b_desc; 115 MatrixLayout c_desc; 116 MatrixLayout d_desc; 117 }; 118 119 class MatmulPreference { 120 public: 121 static port::StatusOr<MatmulPreference> Create(size_t max_workspace_size); 122 get()123 cublasLtMatmulPreference_t get() const { return handle_.get(); } 124 125 private: MatmulPreference(cublasLtMatmulPreference_t handle)126 explicit MatmulPreference(cublasLtMatmulPreference_t handle) 127 : handle_(handle, cublasLtMatmulPreferenceDestroy) {} 128 129 Owned<cublasLtMatmulPreference_t> handle_; 130 }; 131 132 struct MatmulAlgorithm { 133 cublasLtMatmulAlgo_t algo; 134 size_t workspace_size; 135 }; 136 BlasLt(gpu::GpuExecutor * parent)137 explicit BlasLt(gpu::GpuExecutor* parent) 138 : parent_(parent), blas_lt_(nullptr, cublasLtDestroy) {} 139 140 port::Status Init(); 141 142 // Returns the type for the alpha and beta scalars. 143 static blas::DataType GetScaleType(blas::DataType c_type, 144 blas::ComputationType computation_type); 145 146 // Returns a list of supported algorithms for DoMatmul. The algorithms are 147 // returned in the order of increasing estimated compute time according to an 148 // internal heuristic. 149 port::StatusOr<std::vector<MatmulAlgorithm>> GetMatmulAlgorithms( 150 const MatmulPlan& plan, const MatmulPreference& preference, 151 size_t max_algorithm_count = 128); 152 153 template <typename AB, typename CD, typename Scale> 154 port::Status DoMatmul(Stream* stream, const MatmulPlan& plan, 155 const HostOrDeviceScalar<Scale>& alpha, 156 const DeviceMemory<AB>& a, const DeviceMemory<AB>& b, 157 const HostOrDeviceScalar<Scale>& beta, 158 const DeviceMemory<CD>& c, DeviceMemory<CD>& d, 159 const MatmulAlgorithm& algorithm, 160 ScratchAllocator& scratch_allocator, 161 const DeviceMemory<CD>& bias = {}, 162 blas::ProfileResult* profile_result = nullptr) { 163 if (AsCudaDataType(blas::ToDataType<Scale>::value) != 164 plan.op_desc.scale_type()) { 165 return port::InvalidArgumentError("mismatched scale types"); 166 } 167 168 bool expect_scale_factor_on_device = 169 (plan.op_desc.pointer_mode() == CUBLASLT_POINTER_MODE_DEVICE); 170 171 if (alpha.on_device() != expect_scale_factor_on_device) { 172 return port::InvalidArgumentError("wrong location for alpha"); 173 } 174 175 if (beta.on_device() != expect_scale_factor_on_device) { 176 return port::InvalidArgumentError("wrong location for beta"); 177 } 178 179 if (AsCudaDataType(blas::ToDataType<AB>::value) != plan.a_desc.type()) { 180 return port::InvalidArgumentError("mismatched A matrix types"); 181 } 182 183 if (AsCudaDataType(blas::ToDataType<AB>::value) != plan.b_desc.type()) { 184 return port::InvalidArgumentError("mismatched B matrix types"); 185 } 186 187 if (AsCudaDataType(blas::ToDataType<CD>::value) != plan.c_desc.type()) { 188 return port::InvalidArgumentError("mismatched C matrix types"); 189 } 190 191 if (AsCudaDataType(blas::ToDataType<CD>::value) != plan.d_desc.type()) { 192 return port::InvalidArgumentError("mismatched D matrix types"); 193 } 194 195 return DoMatmul(stream, plan, alpha.opaque(), a, b, beta.opaque(), c, d, 196 algorithm, scratch_allocator, bias, profile_result); 197 } 198 199 private: 200 port::Status DoMatmul(Stream* stream, const MatmulPlan& plan, 201 const void* alpha, DeviceMemoryBase a, 202 DeviceMemoryBase b, const void* beta, 203 DeviceMemoryBase c, DeviceMemoryBase d, 204 const MatmulAlgorithm& algorithm, 205 ScratchAllocator& scratch_allocator, 206 DeviceMemoryBase bias, 207 blas::ProfileResult* profile_result); 208 209 gpu::GpuExecutor* parent_; 210 211 absl::Mutex mu_; 212 Owned<cublasLtHandle_t> blas_lt_ ABSL_GUARDED_BY(mu_); 213 }; 214 215 // Returns `BlasLt` implementation for a stream if available, or `nullptr`. 216 BlasLt* GetBlasLt(Stream* stream); 217 218 } // namespace cuda 219 } // namespace stream_executor 220 221 #endif // TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_CUDA_CUDA_BLAS_LT_H_ 222