#include <torch/csrc/jit/runtime/profiling_graph_executor_impl.h>

#include <c10/util/irange.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/add_if_then_else.h>
#include <torch/csrc/jit/passes/bailout_graph.h>
#include <torch/csrc/jit/passes/batch_mm.h>
#include <torch/csrc/jit/passes/canonicalize_graph_fuser_ops.h>
#include <torch/csrc/jit/passes/check_strict_fusion.h>
#include <torch/csrc/jit/passes/clear_profiling.h>
#include <torch/csrc/jit/passes/clear_undefinedness.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/create_autodiff_subgraphs.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/decompose_ops.h>
#include <torch/csrc/jit/passes/graph_fuser.h>
#include <torch/csrc/jit/passes/guard_elimination.h>
#include <torch/csrc/jit/passes/inline_autodiff_subgraphs.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/passes/inplace_check.h>
#include <torch/csrc/jit/passes/insert_guards.h>
#include <torch/csrc/jit/passes/loop_unrolling.h>
#include <torch/csrc/jit/passes/lower_grad_of.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
#include <torch/csrc/jit/passes/pass_manager.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/passes/remove_expands.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/requires_grad_analysis.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/passes/specialize_autogradzero.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/passes/update_differentiable_graph_requires_grad.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <chrono>
#include <mutex>
#include <optional>

C10_DEFINE_bool(
    torch_jit_enable_new_executor,
    true,
    "If this flag is set to false, TorchScript will use the legacy/original executor");

C10_DEFINE_bool(
    torch_jit_disable_warning_prints,
    false,
    "Disables warnings.warn prints in TorchScript graph");

C10_DEFINE_bool(
    torch_jit_static_then_dynamic,
    false,
    "fuse with two static compilations, then 10 dynamic ones");

C10_DEFINE_bool(
    torch_jit_always_dynamic,
    false,
    "fuse with 12 dynamic compilations");

C10_DEFINE_bool(
    torch_jit_release_profiling_graph_after_optimization,
    false,
    "After getOptimizedPlanFor, release the profiling record to reduce memory usage in inference. This is an aggressive memory saving; please be cautious!");

C10_DEFINE_int32(
    torch_jit_release_profiling_graph_delay_in_seconds,
    60,
    "How long to wait before releasing the profiling graph after optimization is done. Only used if torch_jit_release_profiling_graph_after_optimization is set to true.");

constexpr size_t kDefaultNumProfiledRuns = 1;
constexpr size_t kDefaultBailoutDepth = 20;

C10_DEFINE_int64(
    torch_jit_num_profiled_runs,
    kDefaultNumProfiledRuns,
    "Number of profiling runs");
C10_DEFINE_int64(
    torch_jit_bailout_depth,
    kDefaultBailoutDepth,
    "Number of re-specializations");

namespace torch::jit {

namespace {
int32_t getNowInSecs() {
  auto currentTimePoint = std::chrono::system_clock::now();
  auto durationSinceEpoch = std::chrono::duration_cast<std::chrono::seconds>(
      currentTimePoint.time_since_epoch());
  return static_cast<int32_t>(durationSinceEpoch.count());
}
} // namespace

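// On mobile builds the profiling executor is used, but profiling-based
// specialization is off by default; on all other builds both executor mode
// and profiling mode default to on.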
#if defined(C10_MOBILE)
static std::atomic<bool> executor_mode{true};
static std::atomic<bool> profiling_mode{false};
#else
static std::atomic<bool> executor_mode{true};
static std::atomic<bool> profiling_mode{true};
#endif

static std::mutex fusion_strategy_lock;

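// A FusionStrategy is an ordered list of (FusionBehavior, depth) pairs: the
// executor performs up to `depth` specializations under the given behavior
// (STATIC or DYNAMIC) before moving on to the next entry. The defaults chosen
// here can be overridden via the gflags defined above.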
static FusionStrategy getInitialStrategy() {
  if (FLAGS_torch_jit_always_dynamic) {
    return {{FusionBehavior::DYNAMIC, 12}};
  }
  FusionStrategy mixed = {
      {FusionBehavior::STATIC, 2}, {FusionBehavior::DYNAMIC, 10}};
  if (FLAGS_torch_jit_static_then_dynamic) {
    return mixed;
  }
  // TODO remove ifdef
#ifdef FBCODE_CAFFE2
  return {{FusionBehavior::STATIC, 20}};
#endif
  return mixed;
}

// defer initial value so that we can load in gflags
static std::optional<FusionStrategy> fusion_strategy = std::nullopt;

FusionStrategy getFusionStrategy() {
  std::lock_guard<std::mutex> guard(fusion_strategy_lock);
  if (fusion_strategy == std::nullopt) {
    fusion_strategy = getInitialStrategy();
  }
  return *fusion_strategy;
}

FusionStrategy setFusionStrategy(FusionStrategy& strategy) {
  std::lock_guard<std::mutex> guard(fusion_strategy_lock);
  if (fusion_strategy == std::nullopt) {
    fusion_strategy = getInitialStrategy();
  }
  FusionStrategy old_strategy = *fusion_strategy;
  fusion_strategy = strategy;
  return old_strategy;
}
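
// Illustrative sketch (hypothetical caller) of temporarily swapping in a
// different strategy and restoring the previous one afterwards; this is
// roughly what the Python-side torch.jit.set_fusion_strategy ends up doing:
//
//   FusionStrategy dynamic_only = {{FusionBehavior::DYNAMIC, 10}};
//   FusionStrategy previous = setFusionStrategy(dynamic_only);
//   // ... run the workload ...
//   setFusionStrategy(previous);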

static std::atomic<size_t> num_profiled_runs{kDefaultNumProfiledRuns};

std::atomic<bool>& getProfilingMode() {
  return profiling_mode;
}

std::atomic<bool>& getExecutorMode() {
  return executor_mode;
}

std::atomic<size_t>& getNumProfiledRuns() {
  // Initialize num_profiled_runs from command-line flag.
  static const size_t init = []() {
    return num_profiled_runs = FLAGS_torch_jit_num_profiled_runs;
  }();
  (void)init; // Silence clang-tidy.
  return num_profiled_runs;
}

size_t getBailoutDepth() {
  // The bailout depth is the total number of specializations allowed across
  // all entries of the current fusion strategy.
  size_t depth = 0;
  for (const auto& pair : getFusionStrategy()) {
    depth += pair.second;
  }
  return depth;
}

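// Returns true if any prim::BailOut or prim::profile node in the block (or in
// any nested block) recorded a tensor type that requires grad; this decides
// whether the autodiff path is taken in runProfilingOptimizations.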
static bool needsGradientInProfilingMode(Block* b) {
  for (auto n : b->nodes()) {
    if (n->kind() == prim::BailOut) {
      auto ptt = n->output()->type()->expect<TensorType>();
      if (ptt->requiresGrad() && *ptt->requiresGrad()) {
        return true;
      }
    }
    if (n->kind() == prim::profile) {
      auto type = n->ty(attr::profiled_type)->expect<TensorType>();
      if (type->requiresGrad() && *type->requiresGrad()) {
        return true;
      }
    }

    for (auto ib : n->blocks()) {
      if (needsGradientInProfilingMode(ib)) {
        return true;
      }
    }
  }
  return false;
}

// `prim::RequiresGradCheck` guarantees that the requires_grad properties
// of input tensors will match the profiled ones; otherwise a fallback path
// is triggered. This allows us to prune off gradients in the backward
// graph for inputs that don't need gradients. We transfer requires_grad
// properties from the inputs of the `prim::DifferentiableGraph` node onto
// the inputs of the differentiable subgraph. Autodiff will inspect these
// properties and prune off gradients that aren't required.
// `requires_grad` properties from `dnode->outputs()` will also be transferred.
static C10_UNUSED void setRequiresGradOnDiffGraph(Node* dnode) {
  auto gi = dnode->g(attr::Subgraph)->inputs();
  for (size_t i = 0; i < dnode->inputs().size(); i++) {
    if (auto ty = dnode->input(i)->type()->cast<TensorType>()) {
      auto gi_ty = gi[i]->type()->expect<TensorType>();
      gi[i]->setType(gi_ty->withRequiresGrad(ty->requires_grad()));
      GRAPH_DEBUG(
          "Setting ",
          *gi_ty->withRequiresGrad(ty->requires_grad()),
          " on ",
          gi[i],
          " ",
          gi[i]->debugName());
    }
  }

  // We also need to put requires_grad on outputs within subgraph, so autodiff
  // can set df_input_vjps and DifferentiableGraphOp can set `requires_grad=`
  // properly
  auto go = dnode->g(attr::Subgraph)->outputs();
  auto set_requires_grad = [](const TensorTypePtr& t, Value* val) -> bool {
    if (t && t->requiresGrad().has_value()) {
      GRAPH_DEBUG("setting type ", *t);
      val->setType(t);
      return true;
    }
    return false;
  };

  for (const auto i : c10::irange(go.size())) {
    auto ty = go[i]->type()->cast<TensorType>();
    if (ty) {
      auto n = go[i]->node();
      auto dno = dnode->outputs().at(i);
      for (auto dno_use : dno->uses()) {
        GRAPH_DEBUG("found user of ", i, " as ", *dno_use.user);
        if (n->kind() == prim::profile) {
          if (set_requires_grad(
                  n->ty(attr::profiled_type)->expect<TensorType>(), go[i])) {
            break;
          }
        } else if (dno_use.user->kind() == prim::profile) {
          if (set_requires_grad(
                  dno_use.user->ty(attr::profiled_type)->expect<TensorType>(),
                  go[i])) {
            break;
          }
        } else if (dno_use.user->kind() == prim::DifferentiableGraph) {
          Value* o =
              dno_use.user->g(attr::Subgraph)->inputs().at(dno_use.offset);
          // Is it safe to not check other uses, because we are inside a
          // DifferentiableGraph?
          auto nn = o->uses().at(0).user;
          if (nn->kind() == prim::profile) {
            if (set_requires_grad(
                    nn->ty(attr::profiled_type)->expect<TensorType>(), go[i])) {
              break;
            }
          }
        }
      }
    }
  }
}

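// Copies the profiled tensor types onto the inputs of a
// prim::DifferentiableGraph node, propagates requires_grad into the subgraph,
// and guards the node with a prim::RequiresGradCheck. Returns false (and
// unmerges the subgraph) if some tensor input carries no profiling
// information, since such an input cannot be guarded.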
static bool guardDifferentiableGraph(Node* dnode) {
  auto gi = dnode->g(attr::Subgraph)->inputs();
  bool all_inputs_seen = true;
  for (const auto i : c10::irange(gi.size())) {
    auto ty = gi[i]->type()->cast<TensorType>();
    if (ty) {
      auto n = gi[i]->uses().at(0).user;
      auto dni = dnode->inputs().at(i);
      GRAPH_DEBUG("found first user of ", i, " as ", *n);
      if (n->kind() == prim::profile) {
        GRAPH_DEBUG(
            "setting input ", i, " to type ", *n->ty(attr::profiled_type));
        dni->setType(n->ty(attr::profiled_type));
      } else if (dni->node()->kind() == prim::DifferentiableGraph) {
        // The profiling node might have been absorbed into a preceding
        // differentiable graph and thus no longer be present here (which is
        // not ideal for fusing either), see
        // TestAutodiffSubgraphSlicing.test_does_not_create_cycles.
        // Alternatives to this special casing could be specializing the types
        // before autodiff or duplicating profile nodes for autodiff outputs,
        // but that should be done while creating subgraphs and would be
        // a mess.
        // XXX TODO: revisit the alternatives
        Value* o = dni->node()->g(attr::Subgraph)->outputs().at(dni->offset());
        if (o->node()->kind() == prim::profile) {
          dni->setType(o->node()->ty(attr::profiled_type));
        }
      }

      // Propagate the requires_grad property to inputs.
      // A RequiresGradCheck gets added (insertTypeGuard, below),
      // so requires_grad is guaranteed to match for the inputs,
      // but other properties are not guaranteed to match.
      auto requires_grad = dni->type()->expectRef<TensorType>().requiresGrad();
      gi[i]->setType(ty->withRequiresGrad(requires_grad));

      // we check if the optional is defined
      all_inputs_seen &= (dni->type()->cast<TensorType>() != TensorType::get());
    }
  }
  if (all_inputs_seen) {
    // we may have seen both true and false for requires_grad. In this case
    // we guard with true here and handle the other case in the fallback. This
    // will give us trouble when we get "alternating patterns" of gradients
    // of two inputs, but so it is. An alternative could be to look into
    // the individual requires_grad values seen in the profiling record.
    insertTypeGuard(
        dnode,
        [](const TensorTypePtr& t) {
          return TensorType::get()->withRequiresGrad(
              t->requiresGrad().value_or(true));
        },
        prim::RequiresGradCheck);
    return true;
  } else {
    // we inline the differentiable graph as a fallback;
    // ideally we would set this up for re-profiling
    UpdateDifferentiableGraphRequiresGrad(
        dnode->g(attr::Subgraph), std::nullopt);
    SubgraphUtils::unmergeSubgraph(dnode);
    return false;
  }
}

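// Pipeline used when graph executor optimization is disabled: inline, lower
// prim::GradOf, and run the required cleanup passes, without any further
// optimization.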
void runNooptPassPipeline(std::shared_ptr<Graph>& graph) {
  GRAPH_DEBUG("Before Inliner (beginning of runNooptPassPipeline)\n", *graph);
  Inline(*graph);
  GRAPH_DEBUG("After Inline, before LowerGradOf\n", *graph);
  LowerGradOf(*graph);
  GRAPH_DEBUG("After LowerGradOf, before RemoveExpands\n", *graph);
  RemoveExpands(graph);
  GRAPH_DEBUG("After RemoveExpands, before CanonicalizeOps\n", *graph);
  CanonicalizeOps(graph);
  GRAPH_DEBUG("After CanonicalizeOps, before EliminateDeadCode\n", *graph);
  EliminateDeadCode(graph);
  GRAPH_DEBUG(
      "After EliminateDeadCode (end of runNooptPassPipeline)\n", *graph);
}

static void runPreAutodiffPassPipeline(std::shared_ptr<Graph>& graph) {
  GRAPH_DEBUG(
      "Before LowerGradOf (beginning of runPreAutodiffPassPipeline)\n",
      *graph);

  LowerGradOf(*graph);
  GRAPH_DEBUG("After LowerGradOf, before specializeAutogradZero\n", *graph);

  specializeAutogradZero(graph);
  GRAPH_DEBUG("After specializeAutogradZero\n", *graph);
  // runRequiredPasses
  {
    RemoveExpands(graph);
    GRAPH_DEBUG("After RemoveExpands, before CanonicalizeOps\n", *graph);
    CanonicalizeOps(graph);
    GRAPH_DEBUG("After CanonicalizeOps, before EliminateDeadCode\n", *graph);
    EliminateDeadCode(graph);
    GRAPH_DEBUG("After EliminateDeadCode\n", *graph);
  }
  PeepholeOptimize(graph);
  GRAPH_DEBUG("After PeepholeOptimize, before ConstantPropagation\n", *graph);
  ConstantPropagation(graph);

  // runOptimization:
  {
    EliminateDeadCode(graph);
    GRAPH_DEBUG(
        "After EliminateDeadCode, before EliminateCommonSubexpression\n",
        *graph);
    EliminateCommonSubexpression(graph);
    GRAPH_DEBUG(
        "After EliminateCommonSubexpression, before PeepholeOptimize\n",
        *graph);

    PeepholeOptimize(graph);
    GRAPH_DEBUG("After PeepholeOptimize, before ConstantPropagation\n", *graph);
    ConstantPropagation(graph);
    GRAPH_DEBUG("After ConstantPropagation, before ConstantPooling\n", *graph);
    ConstantPooling(graph);
    GRAPH_DEBUG("After ConstantPooling, before UnrollLoops\n", *graph);

    UnrollLoops(graph);
    GRAPH_DEBUG("After UnrollLoops, before RemoveListMutation\n", *graph);
    // run these passes again now that loops are unrolled
    RemoveListMutation(graph);
    GRAPH_DEBUG("After RemoveListMutation, before PeepholeOptimize\n", *graph);
    PeepholeOptimize(graph);
    GRAPH_DEBUG("After PeepholeOptimize, before ConstantPropagation\n", *graph);
    ConstantPropagation(graph);
    GRAPH_DEBUG(
        "After ConstantPropagation, before EliminateCommonSubexpression\n",
        *graph);

    EliminateCommonSubexpression(graph);
    GRAPH_DEBUG(
        "After EliminateCommonSubexpression, before CheckInplace\n", *graph);
    CheckInplace(graph);
  }
  GRAPH_DEBUG(
      "After CheckInplace (end of runPreAutodiffPassPipeline)\n", *graph);
}

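// Maps the remaining bailout depth back to the fusion behavior that applies
// at that depth. Worked example for the default mixed strategy
// {{STATIC, 2}, {DYNAMIC, 10}} (total depth 12): remaining depths 12 and 11
// (roughly the first two compilations) resolve to STATIC, and remaining
// depths 10 down to 1 resolve to DYNAMIC.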
FusionBehavior ProfilingGraphExecutorImpl::getCurrentBehavior(
    size_t remaining_depth) {
  size_t curr_depth = 0;
  for (int i = static_cast<int>(fusion_strategy_.size()) - 1; i >= 0; i--) {
    curr_depth += fusion_strategy_[i].second;
    if (remaining_depth <= curr_depth) {
      return fusion_strategy_[i].first;
    }
  }
  // should never get here
  TORCH_WARN("Strategy changed mid-invocation, NYI");
  return FusionBehavior::STATIC;
}

void ProfilingGraphExecutorImpl::runNoGradOptimizations(
    std::shared_ptr<Graph>& graph,
    size_t remaining_bailout_depth) {
  GRAPH_DEBUG(
      "After customPostPasses (beginning of runNoGradOptimizations)\n", *graph);
  // runNondiffOptimization
  {
    // Run custom passes that different backends can register.
    for (const auto& passPair : getCustomPrePasses()) {
      passPair.first(graph);
    }
    GRAPH_DEBUG("After customPrePasses, before LowerSimpleTuples\n", *graph);

    // TupleConstruct / TupleUnpack pairs can still be present at this point
    // and must be removed for fusion.
    LowerSimpleTuples(graph);
    GRAPH_DEBUG("After LowerSimpleTuples\n", *graph);

    if (tensorExprFuserEnabled()) {
      // Remove prim::profile nodes and embed the profile info directly in the
      // IR in value types. We do this transformation because optimizations
      // that try to merge/fuse nodes in the graph (e.g. BatchMM and GraphFuser)
      // work worse in the presence of interspersed prim::profile nodes.
      // Optimizations relying on the type info are also responsible for
      // inserting proper type checks. Once we're done with these optimizations
      // we will wipe the tensor type information from the IR, so that it's not
      // accidentally used by any other pass.
      RemoveProfileNodesAndSpecializeTypes(graph);
      GRAPH_DEBUG(
          "After RemoveProfileNodesAndSpecializeTypes, before BatchMM\n",
          *graph);
      // Rewrite subgraphs with many MMs into expressions that batch them.
      BatchMM(graph);
      GRAPH_DEBUG("After BatchMM, before Fusion\n", *graph);
      auto min_size = getFusionGroupInlining() ? 2 : 1;
      bool dyn_shapes = getCurrentBehavior(remaining_bailout_depth) ==
          FusionBehavior::DYNAMIC;
      FuseTensorExprs(graph, min_size, /* composed op */ false, dyn_shapes);
      GRAPH_DEBUG("After Fusion, before customPostPasses\n", *graph);
    } else {
      // Rewrite subgraphs with many MMs into expressions that batch them.
      BatchMM(graph);
      GRAPH_DEBUG("After BatchMM, before Fusion\n", *graph);

      FuseGraph(graph, true);
      GRAPH_DEBUG("After Fusion, before customPostPasses\n", *graph);
    }

    // Run custom post-fusion passes
    for (const auto& passPair : getCustomPostPasses()) {
      passPair.first(graph);
    }
    GRAPH_DEBUG(
        "After customPostPasses, before RemoveTensorTypeSpecializations\n",
        *graph);
    RemoveTensorTypeSpecializations(graph);
    GRAPH_DEBUG("After RemoveTensorTypeSpecializations\n", *graph);
  }
  GRAPH_DEBUG("End of runNoGradOptimizations\n");
}

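// Profiling-guided optimization: if the profiled types show that gradients
// are required, create and guard differentiable subgraphs and run autodiff
// on them; otherwise run the no-grad optimization pipeline (BatchMM, fusion,
// custom passes).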
void ProfilingGraphExecutorImpl::runProfilingOptimizations(
    std::shared_ptr<Graph>& copy,
    size_t remaining_bailout_depth) {
  GRAPH_DEBUG("Before runProfilingOptimizations:\n", *copy);
  if (!getGraphExecutorOptimize()) {
    runNooptPassPipeline(copy);
    return;
  }

  runPreAutodiffPassPipeline(copy);

  if (needsGradientInProfilingMode(copy->block())) {
    auto diff_nodes = CreateAutodiffSubgraphs(
        copy,
        getAutodiffSubgraphInlining() ? autodiffSubgraphNodeThreshold : 1);
    GRAPH_DEBUG("After CreateAutodiffSubgraphs\n", *copy);
    size_t idx = 0;
    for (Node* dnode : diff_nodes) {
      GRAPH_DEBUG("Optimizing diff node ", idx, " in ", *copy);
      if (!guardDifferentiableGraph(dnode)) {
        // if we cannot guard (because of inputs without profiling information),
        // we re-inline the subgraph and remove the differentiable node
        GRAPH_DEBUG("Could not guardDifferentiableGraph ", idx, " in ", *copy);
        idx++;
        continue;
      }
      GRAPH_DEBUG("After guardDifferentiableGraph:\n", *copy);
      auto diff_graph = std::move(dnode->g(attr::Subgraph));
      Gradient gradient = differentiate(diff_graph);
      RemoveTensorTypeSpecializations(gradient.f);
      ProfilingRecord::removeProfilingNodes(gradient.f->block());
      GRAPH_DEBUG("Forward graph:\n", *(gradient.f));
      GRAPH_DEBUG("Backward graph:\n", *(gradient.df));
      // just like inside autograd.Functions, the forward of a differentiable
      // graph is essentially in a torch.no_grad context.
      UpdateDifferentiableGraphRequiresGrad(gradient.f, false);
      GRAPH_DEBUG("After UpdateDifferentiableGraphRequiresGrad ", *gradient.f);
      // replaces fallback graphs inserted by TE Fuser
      replaceFallbackGraphWithFallbackFunction(gradient.f->block());
      packGradient(gradient, dnode);
      GRAPH_DEBUG("Finished optimizing diff node ", idx++);
    }
    InlineAutodiffSubgraphs(
        copy,
        getAutodiffSubgraphInlining() ? autodiffSubgraphNodeThreshold : 1);
    replaceFallbackGraphWithFallbackFunction(copy->block());
    ProfilingRecord::removeProfilingNodes(copy->block());
    GRAPH_DEBUG(
        "After InlineAutodiffSubgraphs and Removing Profiling Nodes\n", *copy);
  } else {
    runNoGradOptimizations(copy, remaining_bailout_depth);
  }
  EliminateDeadCode(copy);
  GRAPH_DEBUG("After runProfilingOptimizations:\n", *copy);
}

void ProfilingGraphExecutorImpl::runProfilingInsensitiveOptimizations(
    std::shared_ptr<Graph>& graph) {
  GRAPH_DEBUG(
      "Before inlining (beginning of runProfilingInsensitiveOptimizations)\n",
      *graph);
  // TODO: maybe this can go later in pipeline / directly in autodiff forward
  // creation
  if (getGraphExecutorOptimize()) {
    Inline(*graph);
  }
  GRAPH_DEBUG("After inlining, before ClearProfilingInformation\n", *graph);
  ClearProfilingInformation(graph);
  GRAPH_DEBUG("After ClearProfilingInformation, before LowerGradOf\n", *graph);
  LowerGradOf(*graph);
  GRAPH_DEBUG("After LowerGradOf, before ClearUndefinedness\n", *graph);
  // clear any residual undefinedness, as double backward graph inputs
  // may carry over undefinedness from profiled backward graphs
  ClearUndefinedness(graph);
  // runRequiredPasses
  {
    GRAPH_DEBUG("After ClearUndefinedness, before RemoveExpands\n", *graph);
    RemoveExpands(graph);
    GRAPH_DEBUG("After RemoveExpands, before CanonicalizeOps\n", *graph);
    CanonicalizeOps(graph);
    GRAPH_DEBUG("After CanonicalizeOps, before EliminateDeadCode\n", *graph);
    EliminateDeadCode(graph);
  }
  if (!getGraphExecutorOptimize()) {
    GRAPH_DEBUG(
        "After EliminateDeadCode (end of runProfilingInsensitiveOptimizations)\n",
        *graph);
    return;
  }

  GRAPH_DEBUG("After EliminateDeadCode, before DecomposeOps\n", *graph);
  DecomposeOps(graph);
  GRAPH_DEBUG("After DecomposeOps, before ConstantPropagation\n", *graph);
  ConstantPropagation(graph);
  GRAPH_DEBUG("After ConstantPropagation, before EliminateDeadCode\n", *graph);
  EliminateDeadCode(graph);
  GRAPH_DEBUG(
      "After EliminateDeadCode, before EliminateCommonSubexpression\n", *graph);
  EliminateCommonSubexpression(graph);
  GRAPH_DEBUG(
      "After EliminateCommonSubexpression, before ConstantPooling\n", *graph);
  ConstantPooling(graph);
  GRAPH_DEBUG("After ConstantPooling, before PeepholeOptimize\n", *graph);
  PeepholeOptimize(graph);
  GRAPH_DEBUG("After PeepholeOptimize, before EliminateDeadCode\n", *graph);
  EliminateDeadCode(graph);
  GRAPH_DEBUG("After EliminateDeadCode, before LowerSimpleTuples\n", *graph);
  LowerSimpleTuples(graph);
  GRAPH_DEBUG("After LowerSimpleTuples, before CheckInplace\n", *graph);
  CheckInplace(graph);
  GRAPH_DEBUG(
      "After CheckInplace (end of runProfilingInsensitiveOptimizations)\n",
      *graph);
}

ProfilingGraphExecutorImpl::ProfilingGraphExecutorImpl(
    const std::shared_ptr<Graph>& graph,
    std::string function_name)
    : GraphExecutorImplBase(graph, std::move(function_name)) {
  fusion_strategy_ = getFusionStrategy();
}

size_t ProfilingGraphExecutorImpl::getInstantiatedBailoutDepth() {
  // The bailout depth is the total number of specializations allowed across
  // all entries of the instantiated fusion strategy.
  size_t depth = 0;
  for (const auto& pair : fusion_strategy_) {
    depth += pair.second;
  }
  return depth;
}

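// High-level flow: (1) with optimizations or profiling disabled, build and
// cache a no-opt fallback plan; (2) with a bailout depth of zero, skip
// profiling and only run the profiling-insensitive optimizations; (3)
// otherwise instrument the graph and return the profiling plan until enough
// runs have been observed, then run the profiling-guided optimizations and
// cache the final optimized plan.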
const ExecutionPlan& ProfilingGraphExecutorImpl::getOptimizedPlanFor(
    Stack& stack,
    std::optional<size_t> remaining_bailout_depth) {
  GRAPH_DEBUG("Running ProfilingGraphExecutorImpl ", this);

  // TODO: instantiate simple executor when getProfilingMode() is false
  // no opt mode
  if (!getGraphExecutorOptimize() || !getProfilingMode()) {
    if (!fallback_plan_) {
      auto copy = graph->copy();
      GRAPH_DEBUG(
          "Before LowerGradOf (beginning of runNooptPassPipeline)\n", *copy);
      LowerGradOf(*copy);
      GRAPH_DEBUG("After LowerGradOf, before RemoveExpands\n", *copy);
      RemoveExpands(copy);
      fallback_plan_ = ExecutionPlan(copy, function_name_);
      GRAPH_DUMP("NoOpt Graph: ", copy);
    }
    return *fallback_plan_;
  }

  // If tensorExprFuserEnabled() returns true, we need to persist
  // remaining_bailout_depth_ the very first time ProfilingGraphExecutorImpl is
  // called, so we can update it correctly for fallback functions in
  // ProfilingGraphExecutorImpl. Otherwise, getPlanFor(remaining_bailout_depth)
  // is corrected and persisted by the Code object in the interpreter.
  if (!remaining_bailout_depth_.has_value() || !tensorExprFuserEnabled()) {
    if (remaining_bailout_depth.has_value()) {
      remaining_bailout_depth_ = *remaining_bailout_depth;
    } else {
      remaining_bailout_depth_ = getInstantiatedBailoutDepth();
    }
  }

  // simple executor
  if (*remaining_bailout_depth_ == 0) {
    auto copy = graph->copy();
    runProfilingInsensitiveOptimizations(copy);
    GRAPH_DUMP("Optimized SimpleExecutor Graph: ", copy);
    optimized_plan_ = ExecutionPlan(copy, function_name_);
    time_optimized_plan_created_ = getNowInSecs();
    return *optimized_plan_;
  }

  bool profiling_record_created_in_this_call = false;
  // if a profiling graph hasn't been created yet
  if (!pr_) {
    auto copy = graph->copy();
    runProfilingInsensitiveOptimizations(copy);
    pr_ = ProfilingRecord::instrumentGraph(copy);
    profiling_record_created_in_this_call = true;
    // `InsertProfileNodesForSpecializeAutogradZero` profiles a definition vs a
    // use and it doesn't expect any profile nodes between a graph input and
    // its consumer, `aten::_grad_sum_to_size`. This means we need to run it
    // first, before any other pass that could insert `prim::profile_ivalue`
    // nodes on an `aten::_grad_sum_to_size` input.
    InsertProfileNodesForSpecializeAutogradZero(pr_.get());
    GRAPH_DUMP("Profiled Graph: ", pr_->graph());
    profiling_plan_ = ExecutionPlan(pr_->graph(), function_name_);
    // fall-through
  }

  // profile until a graph is ready
  if (!pr_->ready()) {
    return *profiling_plan_;
  }

  auto copy = pr_->graph()->copy();
  ProfilingRecord::removeProfileCounter(copy->block());
  runProfilingOptimizations(copy, *remaining_bailout_depth_);
  // replaces a fallback graph inserted by
  // specialize_autogradzero if one exists
  replaceFallbackGraphWithFallbackFunction(copy->block());
  runFinalOptimizations(copy);
  CheckStrictFusion(copy);
  GRAPH_DUMP("Optimized Graph: ", copy);
  optimized_plan_ = ExecutionPlan(copy, function_name_);
  time_optimized_plan_created_ = getNowInSecs();
  // If the profiled graph was created in this call, we can release it
  // right away.
  if (FLAGS_torch_jit_release_profiling_graph_after_optimization &&
      profiling_record_created_in_this_call) {
    clearTheGraphCompilationIntermediateGraphs();
  }
  return *optimized_plan_;
}

const ExecutionPlan& ProfilingGraphExecutorImpl::getPlanFor(
    Stack& stack,
    std::optional<size_t> remaining_bailout_depth) {
  std::lock_guard<std::mutex> lock(compile_mutex);

  // IMPORTANT: This is a hot path when calling a TorchScript function. Try not
  // to add any code above this.
  if (optimized_plan_) {
    if (FLAGS_torch_jit_release_profiling_graph_after_optimization &&
        !is_graph_extra_memory_released_) {
      int32_t now = getNowInSecs();
      if ((now - time_optimized_plan_created_) >
          FLAGS_torch_jit_release_profiling_graph_delay_in_seconds) {
        clearTheGraphCompilationIntermediateGraphs();
      }
    }
    return *optimized_plan_;
  }
  // if remaining_bailout_depth is not set, getOptimizedPlanFor falls back to
  // the instantiated bailout depth
  return getOptimizedPlanFor(stack, remaining_bailout_depth);
}

GraphExecutorState ProfilingGraphExecutorImpl::getDebugState() {
  GraphExecutorState state;
  TORCH_INTERNAL_ASSERT(optimized_plan_);
  auto opt_plan = *optimized_plan_;
  state.execution_plans.emplace(ArgumentSpec{0, 0}, opt_plan);
  return state;
}

static Node* insertFallbackFunctionCall(
    Graph* graph,
    GraphFunction* func,
    ArrayRef<Value*> inputs) {
  auto tuple_type = func->graph()->return_node()->input(0)->type();
  Value* fn_constant = graph->insertNode(graph->create(prim::Constant))
                           ->s_(attr::name, func->name())
                           ->i_(Symbol::attr("fallback"), 1)
                           ->output()
                           ->setType(FunctionType::create(func));
  std::vector<Value*> func_call_inputs = {fn_constant};
  func_call_inputs.insert(func_call_inputs.end(), inputs.begin(), inputs.end());
  Value* result =
      graph->insertNode(graph->create(prim::CallFunction, func_call_inputs))
          ->output()
          ->setType(tuple_type);

  auto fun_unpack_tuple = graph->insertNode(graph->createTupleUnpack(result));
  return fun_unpack_tuple;
}

static GraphFunction* createFallbackPathFunction(
    Block* b,
    const std::string& function_name) {
  auto value_map = [](Value* v) { return v; };
  auto graph = std::make_shared<Graph>();
  graph->block()->cloneFrom(b, value_map);

  auto otypes = c10::fmap(
      graph->return_node()->inputs(), [](Value* v) { return v->type(); });
  // a GraphFunction call can only have one output, so all the outputs
  // need to be packed into a tuple
  auto tuple_type = TupleType::create(otypes);
  auto return_tuple = graph->createTuple(graph->return_node()->inputs());
  graph->appendNode(return_tuple);
  for (int i = static_cast<int>(graph->outputs().size()) - 1; i >= 0; i--) {
    graph->eraseOutput(i);
  }
  graph->registerOutput(return_tuple->output());
  return new GraphFunction(function_name, graph, nullptr);
}

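// Replaces every prim::FallbackGraph node in the block (recursively) with a
// call to a freshly created fallback GraphFunction; the fallback's executor
// is compiled with one less unit of bailout depth so that re-specialization
// eventually bottoms out.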
void ProfilingGraphExecutorImpl::replaceFallbackGraphWithFallbackFunction(
    Block* b) {
  Stack s;
  for (auto it = b->nodes().begin(); it != b->nodes().end();) {
    if (it->kind() == prim::FallbackGraph) {
      auto fallback_func = createFallbackPathFunction(
          it->g(attr::Subgraph)->block(), "fallback_function");
      TORCH_INTERNAL_ASSERT(*remaining_bailout_depth_ > 0);
      GRAPH_DEBUG(
          "getPlanFor for ", getHeader(*it), " ", *remaining_bailout_depth_);
      fallback_func->get_executor().getPlanFor(
          s, *remaining_bailout_depth_ - 1);
      fallback_functions_.emplace_back(fallback_func);
      WithInsertPoint wip{*it};
      auto function_call = insertFallbackFunctionCall(
          b->owningGraph(), fallback_func, it->inputs());
      for (const auto i : c10::irange(function_call->outputs().size())) {
        it->output(i)->replaceAllUsesWith(function_call->output(i));
      }
      it.destroyCurrent();
    } else {
      for (Block* ib : it->blocks()) {
        replaceFallbackGraphWithFallbackFunction(ib);
      }
      it++;
    }
  }
}

void ProfilingGraphExecutorImpl::runFinalOptimizations(
    std::shared_ptr<Graph>& graph) {
  AddIfThenElseOp(graph);
}

void ProfilingGraphExecutorImpl::debugFlushCompilationCache() {
  std::lock_guard<std::mutex> lock(compile_mutex);
  pr_.reset();
  fallback_plan_.reset();
  profiling_plan_.reset();
  optimized_plan_.reset();
  // prevent memory leaks
  fallback_functions_.clear();
  remaining_bailout_depth_.reset();
  // TODO - would be nice to have it initialized in subsequent use
  fusion_strategy_ = getFusionStrategy();
  time_optimized_plan_created_ = 0;
  is_graph_extra_memory_released_ = false;
}

void ProfilingGraphExecutorImpl::clearTheGraphCompilationIntermediateGraphs() {
  is_graph_extra_memory_released_ = true;
  profiling_plan_.reset();
  fallback_plan_.reset();
  graph.reset();
  pr_.reset();
}

} // namespace torch::jit