1 // Original TunableOp is from onnxruntime. 2 // https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h 3 // https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable 4 // Copyright (c) Microsoft Corporation. 5 // Licensed under the MIT license. 6 // 7 // Adapting TunableOp into PyTorch 8 // Copyright (c) Advanced Micro Devices, Inc. 9 // 10 #pragma once 11 12 #include <ATen/cuda/tunable/Tunable.h> 13 #include <ATen/cuda/Sleep.h> 14 #include <c10/cuda/CUDACachingAllocator.h> 15 16 #ifndef _WIN32 17 #include <cxxabi.h> 18 #endif 19 20 #include <string> 21 #include <type_traits> 22 #include <unordered_map> 23 #include <vector> 24 25 namespace at::cuda::tunable { 26 27 template <typename ParamsT> 28 class Callable { 29 public: 30 Callable() = default; 31 Callable(Callable&&) = default; 32 virtual ~Callable() = default; Call(const ParamsT *)33 virtual TuningStatus Call(const ParamsT*) { 34 return FAIL; 35 } IsSupported(const ParamsT * params)36 virtual TuningStatus IsSupported(const ParamsT* params) { 37 return Call(params); 38 } 39 }; 40 41 template <typename ParamsT, typename TimerT> 42 class TunableOp { 43 public: 44 TunableOp() = default; 45 TunableOp(TunableOp&&) = default; 46 virtual ~TunableOp() = default; 47 operator()48 TuningStatus operator()(const ParamsT* params) { 49 ResultEntry result = ResultEntry::Null(); 50 TuningContext* ctx = getTuningContext(); 51 if (ctx->IsTunableOpEnabled()) { 52 auto& mgr = ctx->GetTuningResultsManager(); 53 auto op_sig = Signature(); 54 auto params_sig = params->Signature(); 55 result = mgr.Lookup(op_sig, params_sig); 56 // If there is not previous tuning result been found, we do the tuning iff tuning is enabled 57 if (result == ResultEntry::Null() && ctx->IsTuningEnabled()) { 58 result = FindFastest(params); 59 mgr.Add(op_sig, params_sig, result); 60 } 61 } 62 else { 63 result = ResultEntry::Default(); 64 } 65 if (result == ResultEntry::Null()) { 66 TUNABLE_LOG2("no result, using default"); 67 result = ResultEntry::Default(); 68 } 69 auto iter = ops_.find(result); 70 TORCH_CHECK(iter != ops_.end()); 71 return iter->second->Call(params); 72 } 73 Signature()74 virtual std::string Signature() { 75 // According to C++17 standard https://wg21.link/n4659 section 15.7.4 76 // > if the operand of typeid refers to the 77 // > object under construction or destruction, typeid yields the std::type_info object representing the constructor 78 // > or destructor’s class. 79 // So delay the op signature generation. 80 c10::call_once(signature_init_once_, [this]() { signature_ = CreateSignature(); }); 81 return signature_; 82 } 83 84 protected: RegisterOp(const std::string & name,std::unique_ptr<Callable<ParamsT>> op)85 void RegisterOp(const std::string& name, std::unique_ptr<Callable<ParamsT>> op) { 86 this->op_names_.emplace_back(name); 87 this->ops_.emplace(name, std::move(op)); 88 } 89 90 private: WarmUp(Callable<ParamsT> * op,const std::vector<ParamsT * > & param,size_t num_iter,size_t & offset)91 static void WarmUp(Callable<ParamsT> *op, const std::vector<ParamsT*> ¶m, size_t num_iter, size_t &offset) { 92 TuningContext* ctx = getTuningContext(); 93 bool do_flush = ctx->IsICacheFlushEnabled(); 94 for (size_t i = 0; i < num_iter; i++) { 95 if (do_flush) { 96 at::cuda::flush_icache(); 97 } 98 TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK); 99 } 100 } 101 Profile(Callable<ParamsT> * op,const std::vector<ParamsT * > & param,size_t num_iter,size_t & offset)102 static double Profile(Callable<ParamsT> *op, const std::vector<ParamsT*> ¶m, size_t num_iter, size_t &offset) { 103 TuningContext* ctx = getTuningContext(); 104 bool do_flush = ctx->IsICacheFlushEnabled(); 105 TimerT timer{}; 106 timer.Start(); 107 for (size_t i = 0; i < num_iter; i++) { 108 if (do_flush) { 109 at::cuda::flush_icache(); 110 } 111 TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK); 112 } 113 timer.End(); 114 return timer.Duration() / num_iter; 115 } 116 117 protected: FindFastest(const ParamsT * params)118 virtual ResultEntry FindFastest(const ParamsT* params) { 119 TuningContext* ctx = getTuningContext(); 120 auto op_sig = Signature(); 121 auto params_sig = params->Signature(); 122 TUNABLE_LOG2("finding fastest for ", op_sig, '(', params_sig, ')', " out of ", op_names_.size(), " candidates"); 123 auto min_duration_ms = std::numeric_limits<double>::infinity(); 124 std::string id_name = "Default"; 125 ParamsT* reference_params = nullptr; 126 127 // numeric check option is controlled by non-static env var, so check it once per tuned operator 128 bool do_numerics_check = ctx->IsNumericsCheckEnabled(); 129 130 // calcaulte a reference answer for numerical check 131 if (do_numerics_check) { 132 reference_params = params->DeepCopy(false); 133 TORCH_CHECK(ops_[ResultEntry::Default()]->Call(reference_params) == OK); 134 } 135 136 // need copies of params to reuse 137 // make as many copies as will fill the requested rotating buffer size, if requested 138 // rotating_size guaranteed to be >= 0 even though GetRotatingBufferSize() returns int 139 size_t rotating_size = ctx->GetRotatingBufferSize(); 140 bool use_buffer_rotation = (rotating_size > 0); 141 size_t param_size = params->GetSize(use_buffer_rotation); 142 size_t param_count = (rotating_size / param_size) + 1; 143 constexpr size_t MB = 1024*1024; 144 if (use_buffer_rotation) { 145 TUNABLE_LOG2("Rotating buffer ", rotating_size/MB, " MiB. ", 146 "Needed Size: ", param_size/MB, " MiB. ", 147 "Needed number of param copies: ", param_count); 148 } 149 TORCH_CHECK(param_count > 0); 150 151 std::vector<ParamsT*> reusable_params(param_count); 152 for (size_t i = 0; i < param_count; i++) { 153 reusable_params[i] = params->DeepCopy(use_buffer_rotation); 154 } 155 156 // for rotating buffer 157 size_t offset = 0; 158 159 for (size_t i = 0; i < op_names_.size(); i++) { 160 auto* candidate = ops_[op_names_[i]].get(); // borrow pointer 161 162 if (do_numerics_check) { 163 ParamsT* numerical_params = params->DeepCopy(false); 164 auto status = candidate->Call(numerical_params); 165 if (status != OK) { 166 numerical_params->Delete(); 167 TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]); 168 continue; 169 } 170 status = reference_params->NumericalCheck(numerical_params); 171 numerical_params->Delete(); 172 if (status != OK) { 173 TUNABLE_LOG3("├──numerics check failed for id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]); 174 continue; 175 } 176 } 177 else { 178 auto status = candidate->Call(reusable_params[0]); 179 if (status != OK) { 180 TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]); 181 continue; 182 } 183 } 184 185 // collect a small profile 186 constexpr const int approx_num_iter = 3; 187 auto approx_duration = Profile(candidate, reusable_params, approx_num_iter, offset); 188 // bail if too slow 189 if (approx_duration > 2 * min_duration_ms) { 190 TUNABLE_LOG3("├──skip slow instance id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]); 191 continue; 192 } 193 194 // for warmup does user set max duration, max iters, or both? 195 // warmup is allowed to be skipped by setting either iterations or duration to 0 196 double max_warmup_duration = ctx->GetMaxWarmupDurationMs(); 197 int max_warmup_iter = ctx->GetMaxWarmupIterations(); 198 int warmup_iter = 1; // default 199 if (max_warmup_duration >= 0) { 200 int duration_iters = max_warmup_duration / approx_duration; 201 if (max_warmup_iter >= 0) { 202 warmup_iter = std::min(max_warmup_iter, duration_iters); 203 } 204 else { 205 warmup_iter = duration_iters; 206 } 207 } 208 else if (max_warmup_iter >= 0) { 209 warmup_iter = max_warmup_iter; 210 } 211 212 // for tuning does user set max duration, max iters, or both? 213 double max_tuning_duration = ctx->GetMaxTuningDurationMs(); 214 int max_tuning_iter = ctx->GetMaxTuningIterations(); 215 int tuning_iter = 100; // default 216 if (max_tuning_duration > 0) { 217 int duration_iters = max_tuning_duration / approx_duration; 218 if (max_tuning_iter > 0) { 219 tuning_iter = std::min(max_tuning_iter, duration_iters); 220 } 221 else { 222 tuning_iter = duration_iters; 223 } 224 } 225 else if (max_tuning_iter > 0) { 226 tuning_iter = max_tuning_iter; 227 } 228 // tuning must run at least 1 iteration 229 tuning_iter = std::max(1, tuning_iter); 230 231 // do the full warmup followed by tuning 232 double warmup_ms = warmup_iter * approx_duration; 233 double tuning_ms = tuning_iter * approx_duration; 234 TUNABLE_LOG3("├──tuning using " 235 "warmup iters ", warmup_iter, " [", warmup_ms, " ms] " 236 "and tuning iters ", tuning_iter, " [", tuning_ms, " ms] ", 237 "instance id=", i, ", ", op_sig, "(", params_sig, ") ", op_names_[i]); 238 TUNABLE_LOG3("├──offset at ", offset); 239 WarmUp(candidate, reusable_params, warmup_iter, offset); 240 auto duration_ms = Profile(candidate, reusable_params, tuning_iter, offset); 241 if (duration_ms < min_duration_ms) { 242 TUNABLE_LOG3("├──found better instance id=", i, ". " , duration_ms, "ms. ", op_names_[i]); 243 min_duration_ms = duration_ms; 244 id_name = op_names_[i]; 245 } 246 } 247 248 for (size_t i = 0; i < reusable_params.size(); i++) { 249 reusable_params[i]->Delete(); 250 } 251 if (reference_params) { 252 reference_params->Delete(); 253 } 254 255 TUNABLE_LOG2("└──found fastest for ", op_sig, '(', params_sig, ") ", id_name); 256 return ResultEntry(id_name, min_duration_ms); 257 } 258 259 private: CreateSignature()260 std::string CreateSignature() { 261 #ifndef _WIN32 262 const auto* name = typeid(*this).name(); 263 char buf[256]; 264 size_t buf_len = 256; 265 abi::__cxa_demangle(name, buf, &buf_len, nullptr); 266 buf[255] = '\0'; 267 return buf; 268 #else 269 return typeid(*this).name(); 270 #endif 271 } 272 273 mutable c10::once_flag signature_init_once_; 274 std::string signature_; 275 276 std::unordered_map<std::string, std::unique_ptr<Callable<ParamsT>>> ops_; 277 std::vector<std::string> op_names_; 278 }; 279 280 struct OpParams { OpParamsOpParams281 OpParams() {} 282 virtual ~OpParams() = default; 283 virtual std::string Signature() const = 0; 284 }; 285 286 } // namespace at::cuda::tunable 287