// Original TunableOp is from onnxruntime.
// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
//
// Adapting TunableOp into PyTorch
// Copyright (c) Advanced Micro Devices, Inc.
//
#pragma once

#include <ATen/cuda/tunable/Tunable.h>
#include <ATen/cuda/Sleep.h>
#include <c10/cuda/CUDACachingAllocator.h>

#ifndef _WIN32
#include <cxxabi.h>
#endif

#include <string>
#include <type_traits>
#include <unordered_map>
#include <vector>

namespace at::cuda::tunable {

template <typename ParamsT>
class Callable {
  public:
    Callable() = default;
    Callable(Callable&&) = default;
    virtual ~Callable() = default;
    virtual TuningStatus Call(const ParamsT*) {
      return FAIL;
    }
    virtual TuningStatus IsSupported(const ParamsT* params) {
      return Call(params);
    }
};
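
// A Callable wraps one candidate implementation of an op: Call() runs it and
// reports OK or FAIL, and IsSupported() defaults to simply attempting the call.
// A minimal sketch of a concrete candidate (hypothetical names, for
// illustration only):
//
//   class ContiguousCopy : public Callable<CopyParams> {
//     public:
//       TuningStatus Call(const CopyParams* params) override {
//         if (!params->src_is_contiguous) {
//           return FAIL;  // treated as "unsupported" by the tuning loop
//         }
//         // ... launch the candidate kernel on the current stream ...
//         return OK;
//       }
//   };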

template <typename ParamsT, typename TimerT>
class TunableOp {
  public:
    TunableOp() = default;
    TunableOp(TunableOp&&) = default;
    virtual ~TunableOp() = default;

    TuningStatus operator()(const ParamsT* params) {
      ResultEntry result = ResultEntry::Null();
      TuningContext* ctx = getTuningContext();
      if (ctx->IsTunableOpEnabled()) {
        auto& mgr = ctx->GetTuningResultsManager();
        auto op_sig = Signature();
        auto params_sig = params->Signature();
        result = mgr.Lookup(op_sig, params_sig);
        // If no previous tuning result was found, do the tuning only if tuning is enabled
        if (result == ResultEntry::Null() && ctx->IsTuningEnabled()) {
          result = FindFastest(params);
          mgr.Add(op_sig, params_sig, result);
        }
      }
      else {
        result = ResultEntry::Default();
      }
      if (result == ResultEntry::Null()) {
        TUNABLE_LOG2("no result, using default");
        result = ResultEntry::Default();
      }
      auto iter = ops_.find(result);
      TORCH_CHECK(iter != ops_.end());
      return iter->second->Call(params);
    }

    virtual std::string Signature() {
      // According to C++17 standard https://wg21.link/n4659 section 15.7.4
      // > if the operand of typeid refers to the
      // > object under construction or destruction, typeid yields the std::type_info object representing the constructor
      // > or destructor’s class.
      // So delay the op signature generation.
      c10::call_once(signature_init_once_, [this]() { signature_ = CreateSignature(); });
      return signature_;
    }

  protected:
    void RegisterOp(const std::string& name, std::unique_ptr<Callable<ParamsT>> op) {
      this->op_names_.emplace_back(name);
      this->ops_.emplace(name, std::move(op));
    }

  private:
    static void WarmUp(Callable<ParamsT> *op, const std::vector<ParamsT*> &param, size_t num_iter, size_t &offset) {
      TuningContext* ctx = getTuningContext();
      bool do_flush = ctx->IsICacheFlushEnabled();
      for (size_t i = 0; i < num_iter; i++) {
        if (do_flush) {
          at::cuda::flush_icache();
        }
        TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK);
      }
    }
    static double Profile(Callable<ParamsT> *op, const std::vector<ParamsT*> &param, size_t num_iter, size_t &offset) {
      TuningContext* ctx = getTuningContext();
      bool do_flush = ctx->IsICacheFlushEnabled();
      TimerT timer{};
      timer.Start();
      for (size_t i = 0; i < num_iter; i++) {
        if (do_flush) {
          at::cuda::flush_icache();
        }
        TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK);
      }
      timer.End();
      return timer.Duration() / num_iter;
    }
  protected:
    virtual ResultEntry FindFastest(const ParamsT* params) {
      TuningContext* ctx = getTuningContext();
      auto op_sig = Signature();
      auto params_sig = params->Signature();
      TUNABLE_LOG2("finding fastest for ", op_sig, '(', params_sig, ')', " out of ", op_names_.size(), " candidates");
      auto min_duration_ms = std::numeric_limits<double>::infinity();
      std::string id_name = "Default";
      ParamsT* reference_params = nullptr;

      // numerics check option is controlled by a non-static env var, so check it once per tuned operator
      bool do_numerics_check = ctx->IsNumericsCheckEnabled();

      // calculate a reference answer for the numerical check
      if (do_numerics_check) {
        reference_params = params->DeepCopy(false);
        TORCH_CHECK(ops_[ResultEntry::Default()]->Call(reference_params) == OK);
      }

      // need copies of params to reuse
      // make as many copies as will fill the requested rotating buffer size, if requested
      // rotating_size guaranteed to be >= 0 even though GetRotatingBufferSize() returns int
      size_t rotating_size = ctx->GetRotatingBufferSize();
      bool use_buffer_rotation = (rotating_size > 0);
      size_t param_size = params->GetSize(use_buffer_rotation);
      size_t param_count = (rotating_size / param_size) + 1;
      constexpr size_t MB = 1024*1024;
      if (use_buffer_rotation) {
        TUNABLE_LOG2("Rotating buffer ", rotating_size/MB, " MiB. ",
            "Needed Size: ", param_size/MB, " MiB. ",
            "Needed number of param copies: ", param_count);
      }
      TORCH_CHECK(param_count > 0);
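
      // Worked example (illustrative numbers): with a requested rotating buffer
      // of 256 MiB and a per-call param footprint of 96 MiB, integer division
      // gives param_count = (256 / 96) + 1 = 3, so three copies of the params
      // are made and iterations cycle through them, reducing cache-hit effects
      // while timing.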

      std::vector<ParamsT*> reusable_params(param_count);
      for (size_t i = 0; i < param_count; i++) {
        reusable_params[i] = params->DeepCopy(use_buffer_rotation);
      }

      // for rotating buffer
      size_t offset = 0;

      for (size_t i = 0; i < op_names_.size(); i++) {
        auto* candidate = ops_[op_names_[i]].get(); // borrow pointer

        if (do_numerics_check) {
          ParamsT* numerical_params = params->DeepCopy(false);
          auto status = candidate->Call(numerical_params);
          if (status != OK) {
            numerical_params->Delete();
            TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
            continue;
          }
          status = reference_params->NumericalCheck(numerical_params);
          numerical_params->Delete();
          if (status != OK) {
            TUNABLE_LOG3("├──numerics check failed for id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
            continue;
          }
        }
        else {
          auto status = candidate->Call(reusable_params[0]);
          if (status != OK) {
            TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
            continue;
          }
        }

        // collect a small profile
        constexpr const int approx_num_iter = 3;
        auto approx_duration = Profile(candidate, reusable_params, approx_num_iter, offset);
        // bail if too slow
        if (approx_duration > 2 * min_duration_ms) {
          TUNABLE_LOG3("├──skip slow instance id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
          continue;
        }

        // for warmup, did the user set a max duration, max iters, or both?
        // warmup is allowed to be skipped by setting either iterations or duration to 0
        double max_warmup_duration = ctx->GetMaxWarmupDurationMs();
        int max_warmup_iter = ctx->GetMaxWarmupIterations();
        int warmup_iter = 1; // default
        if (max_warmup_duration >= 0) {
          int duration_iters = max_warmup_duration / approx_duration;
          if (max_warmup_iter >= 0) {
            warmup_iter = std::min(max_warmup_iter, duration_iters);
          }
          else {
            warmup_iter = duration_iters;
          }
        }
        else if (max_warmup_iter >= 0) {
          warmup_iter = max_warmup_iter;
        }

        // for tuning, did the user set a max duration, max iters, or both?
        double max_tuning_duration = ctx->GetMaxTuningDurationMs();
        int max_tuning_iter = ctx->GetMaxTuningIterations();
        int tuning_iter = 100; // default
        if (max_tuning_duration > 0) {
          int duration_iters = max_tuning_duration / approx_duration;
          if (max_tuning_iter > 0) {
            tuning_iter = std::min(max_tuning_iter, duration_iters);
          }
          else {
            tuning_iter = duration_iters;
          }
        }
        else if (max_tuning_iter > 0) {
          tuning_iter = max_tuning_iter;
        }
        // tuning must run at least 1 iteration
        tuning_iter = std::max(1, tuning_iter);
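
        // Worked example (illustrative numbers): if approx_duration is 0.5 ms,
        // max_tuning_duration is 30 ms and max_tuning_iter is 100, then
        // duration_iters = 30 / 0.5 = 60 and tuning_iter = min(100, 60) = 60;
        // the duration cap wins for slow candidates, the iteration cap for
        // fast ones.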

        // do the full warmup followed by tuning
        double warmup_ms = warmup_iter * approx_duration;
        double tuning_ms = tuning_iter * approx_duration;
        TUNABLE_LOG3("├──tuning using "
            "warmup iters ", warmup_iter, " [", warmup_ms, " ms] "
            "and tuning iters ", tuning_iter, " [", tuning_ms, " ms] ",
            "instance id=", i, ", ", op_sig, "(", params_sig, ") ", op_names_[i]);
        TUNABLE_LOG3("├──offset at ", offset);
        WarmUp(candidate, reusable_params, warmup_iter, offset);
        auto duration_ms = Profile(candidate, reusable_params, tuning_iter, offset);
        if (duration_ms < min_duration_ms) {
          TUNABLE_LOG3("├──found better instance id=", i, ". " , duration_ms, "ms. ", op_names_[i]);
          min_duration_ms = duration_ms;
          id_name = op_names_[i];
        }
      }

      for (size_t i = 0; i < reusable_params.size(); i++) {
        reusable_params[i]->Delete();
      }
      if (reference_params) {
        reference_params->Delete();
      }

      TUNABLE_LOG2("└──found fastest for ", op_sig, '(', params_sig, ") ", id_name);
      return ResultEntry(id_name, min_duration_ms);
    }

  private:
    std::string CreateSignature() {
#ifndef _WIN32
      const auto* name = typeid(*this).name();
      char buf[256];
      size_t buf_len = 256;
      abi::__cxa_demangle(name, buf, &buf_len, nullptr);
      buf[255] = '\0';
      return buf;
#else
      return typeid(*this).name();
#endif
    }

    mutable c10::once_flag signature_init_once_;
    std::string signature_;

    std::unordered_map<std::string, std::unique_ptr<Callable<ParamsT>>> ops_;
    std::vector<std::string> op_names_;
};

struct OpParams {
  OpParams() {}
  virtual ~OpParams() = default;
  virtual std::string Signature() const = 0;
};
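
// Sketch of how the pieces above are typically assembled (hypothetical names;
// see e.g. TunableGemm.h for concrete uses). A params struct derives from
// OpParams and, to be tunable, also provides the DeepCopy()/Delete()/GetSize()/
// NumericalCheck() helpers that FindFastest() relies on. The op subclasses
// TunableOp, registers its candidates by name in the constructor, and is then
// called like a function. Note that operator() and FindFastest() fall back to
// the "Default" entry (ResultEntry::Default()), so one candidate should be
// registered under that name.
//
//   template <typename T>
//   class MyOpTunableOp : public TunableOp<MyOpParams<T>, StreamTimer> {
//     public:
//       MyOpTunableOp() {
//         this->RegisterOp("Default", std::make_unique<DefaultMyOp<T>>());
//         this->RegisterOp("CandidateA", std::make_unique<CandidateAMyOp<T>>());
//       }
//   };
//
//   // at a call site:
//   static MyOpTunableOp<float> op;
//   MyOpParams<float> params{/*...*/};
//   op(&params);  // looks up a prior result or tunes, then dispatches it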

} // namespace at::cuda::tunable