// Original TunableOp is from onnxruntime.
// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
//
// Adapting TunableOp into PyTorch
// Copyright (c) Advanced Micro Devices, Inc.
//
#pragma once

#include <string>

#include <ATen/cuda/tunable/TunableOp.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/util/StringUtil.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/allclose.h>
#include <ATen/ops/from_blob.h>
#endif

namespace at::cuda::tunable {

enum class BlasOp {
  N = 0,
  T = 1
};

inline std::string BlasOpToString(BlasOp op) {
  switch (op) {
    case BlasOp::N:
      return "N";
    case BlasOp::T:
      return "T";
  }
  TORCH_CHECK(false, "unrecognized BlasOp");
  return "N";
}
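
// Hedged usage sketch (illustrative; not part of this header's API surface):
// BlasOpToString can help compose human-readable operator names, e.g. a
// tunable GEMM computing C = A^T * B could be labeled like this:
//
//   std::string name = c10::str("GemmTunableOp_",
//                               BlasOpToString(BlasOp::T),   // layout of A
//                               BlasOpToString(BlasOp::N));  // layout of B
//   // name == "GemmTunableOp_TN"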

namespace detail {

// Compare two same-dtype device buffers element-wise, trying progressively
// tighter tolerance pairs and remembering the tightest pair that passes.
static bool NumericalCheck(ScalarType dtype, void* c, void* other_c, int64_t size) {
  auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA);
  // comparison done as 1D tensor
  at::Tensor ref = at::from_blob(c, {size}, options);
  at::Tensor oth = at::from_blob(other_c, {size}, options);
  at::Tensor ref_float = ref.to(at::kFloat);
  at::Tensor oth_float = oth.to(at::kFloat);
  std::vector<double> atols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
  std::vector<double> rtols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
  double last_succeed_atol = 1;
  double last_succeed_rtol = 1;
  for (auto& atol : atols) {
    for (auto& rtol : rtols) {
      if (at::allclose(ref_float, oth_float, rtol, atol)) {
        last_succeed_atol = atol;
        last_succeed_rtol = rtol;
      }
    }
  }
  if (last_succeed_atol == 1) {
    return false;
  }
  else {
    TUNABLE_LOG3("├──verify numerics: atol=", last_succeed_atol, ", rtol=", last_succeed_rtol);
  }

  return true;
}
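
// Hedged usage sketch: detail::NumericalCheck compares two device buffers of
// the same dtype and length as flat 1D tensors. The buffer names below are
// hypothetical.
//
//   // ref_c:  output written by the reference GEMM
//   // cand_c: output written by a candidate kernel
//   bool ok = NumericalCheck(at::kFloat, ref_c, cand_c, /*size=*/m * n);
//   // true iff at least one tolerance pair passes; the tightest passing
//   // tolerances are logged via TUNABLE_LOG3.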

} // namespace detail

template <typename T>
struct GemmParams : OpParams {
  GemmParams() {
    duplicate_inputs_ = false;
  }

  std::string Signature() const override {
    return c10::str(transa, transb, "_", m, "_", n, "_", k);
  }
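
  // Illustrative example (values assumed): transa='T', transb='N', m=1024,
  // n=768, k=512 produces the signature "TN_1024_768_512", which serves as
  // part of the key for looking up previously tuned solutions.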

  size_t GetSizeA() const {
    return sizeof(T) * lda * ((transa == 'n' || transa == 'N') ? k : m);
  }

  size_t GetSizeB() const {
    return sizeof(T) * ldb * ((transb == 'n' || transb == 'N') ? n : k);
  }

  size_t GetSizeC() const {
    return sizeof(T) * ldc * n;
  }

  size_t GetSize(bool duplicate_inputs) const {
    size_t size = GetSizeC();
    if (duplicate_inputs) {
      size += GetSizeA();
      size += GetSizeB();
    }
    return size;
  }
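
  // Worked example (values assumed): for T=float with transa='N', lda=1024,
  // k=512, GetSizeA() = 4 * 1024 * 512 bytes = 2 MiB. GetSize(true) is the
  // total DeepCopy(true) will allocate: C plus scratch duplicates of A and B.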

  GemmParams* DeepCopy(bool duplicate_inputs) const {
    GemmParams* copy = new GemmParams;
    *copy = *this;
    c10::DeviceIndex device = 0;
    AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
    size_t c_size = GetSizeC();
    copy->c = static_cast<T*>(c10::cuda::CUDACachingAllocator::raw_alloc(c_size));
    AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync(
        copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true));
    if (duplicate_inputs) {
      size_t a_size = GetSizeA();
      size_t b_size = GetSizeB();
      // scratch A/B buffers are allocated but their contents are not copied
      copy->a = static_cast<const T*>(c10::cuda::CUDACachingAllocator::raw_alloc(a_size));
      copy->b = static_cast<const T*>(c10::cuda::CUDACachingAllocator::raw_alloc(b_size));
      copy->duplicate_inputs_ = true;
    }
    return copy;
  }

  // only call on object returned by DeepCopy
  void Delete() {
    c10::cuda::CUDACachingAllocator::raw_delete(c);
    if (duplicate_inputs_) {
      c10::cuda::CUDACachingAllocator::raw_delete(const_cast<T*>(a));
      c10::cuda::CUDACachingAllocator::raw_delete(const_cast<T*>(b));
    }
  }

  TuningStatus NumericalCheck(GemmParams<T> *other) {
    auto c_dtype = c10::CppTypeToScalarType<T>::value;
    return detail::NumericalCheck(c_dtype, c, other->c, ldc*n) ? OK : FAIL;
  }
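
  // Hedged lifecycle sketch (the driver loop shown is illustrative, not the
  // actual TunableOp tuning code):
  //
  //   GemmParams<float>* trial = reference->DeepCopy(/*duplicate_inputs=*/true);
  //   // ... launch a candidate kernel that writes into trial->c ...
  //   TuningStatus status = trial->NumericalCheck(reference);  // OK or FAIL
  //   trial->Delete();  // frees only the buffers DeepCopy allocated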

  char transa;
  char transb;
  int64_t m;
  int64_t n;
  int64_t k;
  at::opmath_type<T> alpha;
  const T* a;
  int64_t lda;
  const T* b;
  int64_t ldb;
  at::opmath_type<T> beta;
  T* c;
  int64_t ldc;
 private:
  bool duplicate_inputs_;
};

template <typename T>
struct GemmStridedBatchedParams : OpParams {
  GemmStridedBatchedParams() {
    duplicate_inputs_ = false;
  }

  std::string Signature() const override {
    return c10::str(transa, transb, "_", m, "_", n, "_", k, "_B_", batch);
  }
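
  // Illustrative example (values assumed): transa='N', transb='T', m=64,
  // n=64, k=32, batch=8 yields "NT_64_64_32_B_8", so batched problems do not
  // collide with plain GemmParams signatures.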

  size_t GetSizeA() const {
    return sizeof(T) * lda * ((transa == 'n' || transa == 'N') ? k : m) * batch;
  }

  size_t GetSizeB() const {
    return sizeof(T) * ldb * ((transb == 'n' || transb == 'N') ? n : k) * batch;
  }

  size_t GetSizeC() const {
    return sizeof(T) * ldc * n * batch;
  }

  size_t GetSize(bool duplicate_inputs) const {
    size_t size = GetSizeC();
    if (duplicate_inputs) {
      size += GetSizeA();
      size += GetSizeB();
    }
    return size;
  }

  GemmStridedBatchedParams* DeepCopy(bool duplicate_inputs) const {
    GemmStridedBatchedParams* copy = new GemmStridedBatchedParams;
    *copy = *this;
    c10::DeviceIndex device = 0;
    AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
    size_t c_size = GetSizeC();
    copy->c = static_cast<T*>(c10::cuda::CUDACachingAllocator::raw_alloc(c_size));
    AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync(
        copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true));
    if (duplicate_inputs) {
      size_t a_size = GetSizeA();
      size_t b_size = GetSizeB();
      copy->a = static_cast<const T*>(c10::cuda::CUDACachingAllocator::raw_alloc(a_size));
      copy->b = static_cast<const T*>(c10::cuda::CUDACachingAllocator::raw_alloc(b_size));
      copy->duplicate_inputs_ = true;
    }
    return copy;
  }

  // only call on object returned by DeepCopy
  void Delete() {
    c10::cuda::CUDACachingAllocator::raw_delete(c);
    if (duplicate_inputs_) {
      c10::cuda::CUDACachingAllocator::raw_delete(const_cast<T*>(a));
      c10::cuda::CUDACachingAllocator::raw_delete(const_cast<T*>(b));
    }
  }

  TuningStatus NumericalCheck(GemmStridedBatchedParams<T> *other) {
    auto c_dtype = c10::CppTypeToScalarType<T>::value;
    return detail::NumericalCheck(c_dtype, c, other->c, batch*stride_c) ? OK : FAIL;
  }

  char transa;
  char transb;
  int64_t m;
  int64_t n;
  int64_t k;
  at::opmath_type<T> alpha;
  const T* a;
  int64_t lda;
  int64_t stride_a;
  const T* b;
  int64_t ldb;
  int64_t stride_b;
  at::opmath_type<T> beta;
  T* c;
  int64_t ldc;
  int64_t stride_c;
  int64_t batch;
 private:
  bool duplicate_inputs_;
};

template <typename T>
struct ScaledGemmParams : OpParams {
  ScaledGemmParams() {
    duplicate_inputs_ = false;
  }

  std::string Signature() const override {
    return c10::str(transa, transb, "_", m, "_", n, "_", k);
  }

  size_t GetSizeA() const {
    return sizeof(T) * lda * ((transa == 'n' || transa == 'N') ? k : m);
  }

  size_t GetSizeB() const {
    return sizeof(T) * ldb * ((transb == 'n' || transb == 'N') ? n : k);
  }

  size_t GetSizeC() const {
    return sizeof(T) * ldc * n;
  }

  size_t GetSize(bool duplicate_inputs) const {
    size_t size = GetSizeC();
    if (duplicate_inputs) {
      size += GetSizeA();
      size += GetSizeB();
    }
    return size;
  }

  ScaledGemmParams* DeepCopy(bool duplicate_inputs) const {
    ScaledGemmParams* copy = new ScaledGemmParams;
    *copy = *this;
    c10::DeviceIndex device = 0;
    AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
    size_t c_size = GetSizeC();
    copy->c = c10::cuda::CUDACachingAllocator::raw_alloc(c_size);
    AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync(
        copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true));
    if (duplicate_inputs) {
      size_t a_size = GetSizeA();
      size_t b_size = GetSizeB();
      copy->a = c10::cuda::CUDACachingAllocator::raw_alloc(a_size);
      copy->b = c10::cuda::CUDACachingAllocator::raw_alloc(b_size);
      copy->duplicate_inputs_ = true;
    }
    return copy;
  }

  // only call on object returned by DeepCopy
  void Delete() {
    c10::cuda::CUDACachingAllocator::raw_delete(c);
    if (duplicate_inputs_) {
      c10::cuda::CUDACachingAllocator::raw_delete(const_cast<void*>(a));
      c10::cuda::CUDACachingAllocator::raw_delete(const_cast<void*>(b));
    }
  }

  TuningStatus NumericalCheck(ScaledGemmParams<T> *other) {
    return detail::NumericalCheck(c_dtype, c, other->c, ldc*n) ? OK : FAIL;
  }

  char transa;
  char transb;
  int64_t m;
  int64_t n;
  int64_t k;
  const void* a;
  const void* a_scale_ptr;
  int64_t lda;
  ScalarType a_dtype;
  const void* b;
  const void* b_scale_ptr;
  int64_t ldb;
  ScalarType b_dtype;
  const void* bias_ptr;
  ScalarType bias_dtype;
  void* c;
  const void* c_scale_ptr;
  int64_t ldc;
  ScalarType c_dtype;
  void* amax_ptr;
  bool use_fast_accum;
 private:
  bool duplicate_inputs_;
};
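
// Hedged sketch (hypothetical buffer names, assumed dtypes) of the extra
// state ScaledGemmParams carries: type-erased operands with explicit
// per-operand dtypes and scale pointers, as used for scaled (e.g. FP8) GEMM.
//
//   ScaledGemmParams<c10::Float8_e4m3fn> params;
//   params.transa = 'T';  params.transb = 'N';
//   params.m = params.n = params.k = 1024;
//   params.a = a_buf;  params.a_scale_ptr = a_scale;  params.lda = 1024;
//   params.a_dtype = at::kFloat8_e4m3fn;
//   params.b = b_buf;  params.b_scale_ptr = b_scale;  params.ldb = 1024;
//   params.b_dtype = at::kFloat8_e4m3fn;
//   params.bias_ptr = nullptr;
//   params.c = c_buf;  params.c_scale_ptr = c_scale;  params.ldc = 1024;
//   params.c_dtype = at::kBFloat16;
//   params.amax_ptr = amax_buf;
//   params.use_fast_accum = true;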

} // namespace at::cuda::tunable