xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <torch/csrc/jit/runtime/profiling_graph_executor_impl.h>

#include <c10/util/irange.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/add_if_then_else.h>
#include <torch/csrc/jit/passes/bailout_graph.h>
#include <torch/csrc/jit/passes/batch_mm.h>
#include <torch/csrc/jit/passes/canonicalize_graph_fuser_ops.h>
#include <torch/csrc/jit/passes/check_strict_fusion.h>
#include <torch/csrc/jit/passes/clear_profiling.h>
#include <torch/csrc/jit/passes/clear_undefinedness.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/create_autodiff_subgraphs.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/decompose_ops.h>
#include <torch/csrc/jit/passes/graph_fuser.h>
#include <torch/csrc/jit/passes/guard_elimination.h>
#include <torch/csrc/jit/passes/inline_autodiff_subgraphs.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/passes/inplace_check.h>
#include <torch/csrc/jit/passes/insert_guards.h>
#include <torch/csrc/jit/passes/loop_unrolling.h>
#include <torch/csrc/jit/passes/lower_grad_of.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
#include <torch/csrc/jit/passes/pass_manager.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/passes/remove_expands.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/requires_grad_analysis.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/passes/specialize_autogradzero.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/passes/update_differentiable_graph_requires_grad.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <chrono>
#include <mutex>
#include <optional>
C10_DEFINE_bool(
    torch_jit_enable_new_executor,
    true,
    "If this flag is set to false, TorchScript will use the legacy/original executor");

C10_DEFINE_bool(
    torch_jit_disable_warning_prints,
    false,
    "Disables warning.warn prints in TorchScript graph");

C10_DEFINE_bool(
    torch_jit_static_then_dynamic,
    false,
    "Fuse using 2 static compilations followed by 10 dynamic compilations");

C10_DEFINE_bool(
    torch_jit_always_dynamic,
    false,
    "Fuse using 12 dynamic compilations");

C10_DEFINE_bool(
    torch_jit_release_profiling_graph_after_optimization,
    false,
    "After getOptimizedPlanFor, release the profiling/optimization record to reduce memory use during inference. This is an aggressive memory saving; please be cautious!");

C10_DEFINE_int32(
    torch_jit_release_profiling_graph_delay_in_seconds,
    60,
    "How long to wait before releasing the profiling graph after optimization is done. Only used if torch_jit_release_profiling_graph_after_optimization is set to true.");

constexpr size_t kDefaultNumProfiledRuns = 1;
constexpr size_t kDefaultBailoutDepth = 20;

C10_DEFINE_int64(
    torch_jit_num_profiled_runs,
    kDefaultNumProfiledRuns,
    "Number of profiling runs");
C10_DEFINE_int64(
    torch_jit_bailout_depth,
    kDefaultBailoutDepth,
    "Number of re-specializations");
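
// Example (illustrative only): these gflags-style knobs can also be set
// directly from C++ through the FLAGS_* variables that the C10_DEFINE_*
// macros create, e.g. before the first call into a scripted function:
//
//   FLAGS_torch_jit_release_profiling_graph_after_optimization = true;
//   FLAGS_torch_jit_release_profiling_graph_delay_in_seconds = 30;
//
// The value 30 above is an arbitrary example, not a recommended setting.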
namespace torch::jit {

namespace {
int32_t getNowInSecs() {
  auto currentTimePoint = std::chrono::system_clock::now();
  auto durationSinceEpoch = std::chrono::duration_cast<std::chrono::seconds>(
      currentTimePoint.time_since_epoch());
  return static_cast<int32_t>(durationSinceEpoch.count());
}
} // namespace

#if defined(C10_MOBILE)
static std::atomic<bool> executor_mode{true};
static std::atomic<bool> profiling_mode{false};
#else
static std::atomic<bool> executor_mode{true};
static std::atomic<bool> profiling_mode{true};
#endif

static std::mutex fusion_strategy_lock;

static FusionStrategy getInitialStrategy() {
  if (FLAGS_torch_jit_always_dynamic) {
    return {{FusionBehavior::DYNAMIC, 12}};
  }
  FusionStrategy mixed = {
      {FusionBehavior::STATIC, 2}, {FusionBehavior::DYNAMIC, 10}};
  if (FLAGS_torch_jit_static_then_dynamic) {
    return mixed;
  }
// TODO remove ifdef
#ifdef FBCODE_CAFFE2
  return {{FusionBehavior::STATIC, 20}};
#endif
  return mixed;
}

// defer the initial value so that we can load in gflags
static std::optional<FusionStrategy> fusion_strategy = std::nullopt;

FusionStrategy getFusionStrategy() {
  std::lock_guard<std::mutex> guard(fusion_strategy_lock);
  if (fusion_strategy == std::nullopt) {
    fusion_strategy = getInitialStrategy();
  }
  return *fusion_strategy;
}

FusionStrategy setFusionStrategy(FusionStrategy& strategy) {
  std::lock_guard<std::mutex> guard(fusion_strategy_lock);
  if (fusion_strategy == std::nullopt) {
    fusion_strategy = getInitialStrategy();
  }
  FusionStrategy old_strategy = *fusion_strategy;
  fusion_strategy = strategy;
  return old_strategy;
}
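
// Usage sketch (illustrative, not part of this file): a caller can override
// the default strategy before compilation, e.g. to request one static
// specialization followed by dynamic-shape fusion. The counts below are
// hypothetical.
//
//   FusionStrategy strategy = {
//       {FusionBehavior::STATIC, 1}, {FusionBehavior::DYNAMIC, 5}};
//   FusionStrategy previous = setFusionStrategy(strategy);
//   // ... run the scripted function ...
//   setFusionStrategy(previous); // restore the old strategy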

static std::atomic<size_t> num_profiled_runs{kDefaultNumProfiledRuns};

std::atomic<bool>& getProfilingMode() {
  return profiling_mode;
}

std::atomic<bool>& getExecutorMode() {
  return executor_mode;
}
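
// Usage sketch (illustrative only): these accessors return mutable atomics,
// so embedders can toggle behavior directly from C++, e.g.
//
//   getProfilingMode() = false;
//
// With profiling disabled, getOptimizedPlanFor below compiles and caches the
// no-opt fallback plan instead of a profiled/specialized one.
// getExecutorMode() is consulted outside this file to choose between this
// executor and the legacy one (an assumption based on its name and defaults).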

std::atomic<size_t>& getNumProfiledRuns() {
  // Initialize num_profiled_runs from the command-line flag.
  static const size_t init = []() {
    return num_profiled_runs = FLAGS_torch_jit_num_profiled_runs;
  }();
  (void)init; // Silence clang-tidy.
  return num_profiled_runs;
}

size_t getBailoutDepth() {
  // The bailout depth is the total number of (re)specializations allowed by
  // the current fusion strategy.
  size_t depth = 0;
  for (const auto& pair : getFusionStrategy()) {
    depth += pair.second;
  }
  return depth;
}
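
// Worked example: with the default mixed strategy
// {{STATIC, 2}, {DYNAMIC, 10}}, the bailout depth is 2 + 10 = 12, i.e. up to
// 12 specializations before the executor stops specializing and compiles a
// profiling-insensitive plan (see the "simple executor" branch in
// getOptimizedPlanFor below).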

static bool needsGradientInProfilingMode(Block* b) {
  for (auto n : b->nodes()) {
    if (n->kind() == prim::BailOut) {
      auto ptt = n->output()->type()->expect<TensorType>();
      if (ptt->requiresGrad() && *ptt->requiresGrad()) {
        return true;
      }
    }
    if (n->kind() == prim::profile) {
      auto type = n->ty(attr::profiled_type)->expect<TensorType>();
      if (type->requiresGrad() && *type->requiresGrad()) {
        return true;
      }
    }

    for (auto ib : n->blocks()) {
      if (needsGradientInProfilingMode(ib)) {
        return true;
      }
    }
  }
  return false;
}
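
// Illustrative example (hypothetical IR, for orientation only): the function
// above returns true for a profiled graph containing a node such as
//
//   %x.2 : Tensor = prim::profile[profiled_type=Float(2, 2, requires_grad=1)](%x.1)
//
// because the recorded profiled_type requires grad.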

// `prim::RequiresGradCheck` guarantees that the requires_grad properties
// of the input tensors will match the profiled ones; otherwise a fallback
// path is triggered. This allows us to prune off gradients in the backward
// graph for inputs that don't need gradients. We transfer requires_grad
// properties from the inputs of the `prim::DifferentiableGraph` node onto
// the inputs of the differentiable subgraph. Autodiff will inspect these
// properties and prune off gradients that aren't required.
// `requires_grad` properties from `dnode->outputs()` will also be transferred.
static C10_UNUSED void setRequiresGradOnDiffGraph(Node* dnode) {
  auto gi = dnode->g(attr::Subgraph)->inputs();
  for (size_t i = 0; i < dnode->inputs().size(); i++) {
    if (auto ty = dnode->input(i)->type()->cast<TensorType>()) {
      auto gi_ty = gi[i]->type()->expect<TensorType>();
      gi[i]->setType(gi_ty->withRequiresGrad(ty->requires_grad()));
      GRAPH_DEBUG(
          "Setting ",
          *gi_ty->withRequiresGrad(ty->requires_grad()),
          " on ",
          gi[i],
          " ",
          gi[i]->debugName());
    }
  }

  // We also need to put requires_grad on outputs within the subgraph, so
  // autodiff can set df_input_vjps and DifferentiableGraphOp can set
  // `requires_grad=` properly.
  auto go = dnode->g(attr::Subgraph)->outputs();
  auto set_requires_grad = [](const TensorTypePtr& t, Value* val) -> bool {
    if (t && t->requiresGrad().has_value()) {
      GRAPH_DEBUG("setting type ", *t);
      val->setType(t);
      return true;
    }
    return false;
  };

  for (const auto i : c10::irange(go.size())) {
    auto ty = go[i]->type()->cast<TensorType>();
    if (ty) {
      auto n = go[i]->node();
      auto dno = dnode->outputs().at(i);
      for (auto dno_use : dno->uses()) {
        GRAPH_DEBUG("found user of ", i, " as ", *dno_use.user);
        if (n->kind() == prim::profile) {
          if (set_requires_grad(
                  n->ty(attr::profiled_type)->expect<TensorType>(), go[i])) {
            break;
          }
        } else if (dno_use.user->kind() == prim::profile) {
          if (set_requires_grad(
                  dno_use.user->ty(attr::profiled_type)->expect<TensorType>(),
                  go[i])) {
            break;
          }
        } else if (dno_use.user->kind() == prim::DifferentiableGraph) {
          Value* o =
              dno_use.user->g(attr::Subgraph)->inputs().at(dno_use.offset);
          // Is it safe to not check other uses, because we are inside a
          // DifferentiableGraph?
          auto nn = o->uses().at(0).user;
          if (nn->kind() == prim::profile) {
            if (set_requires_grad(
                    nn->ty(attr::profiled_type)->expect<TensorType>(), go[i])) {
              break;
            }
          }
        }
      }
    }
  }
}

static bool guardDifferentiableGraph(Node* dnode) {
  auto gi = dnode->g(attr::Subgraph)->inputs();
  bool all_inputs_seen = true;
  for (const auto i : c10::irange(gi.size())) {
    auto ty = gi[i]->type()->cast<TensorType>();
    if (ty) {
      auto n = gi[i]->uses().at(0).user;
      auto dni = dnode->inputs().at(i);
      GRAPH_DEBUG("found first user of ", i, " as ", *n);
      if (n->kind() == prim::profile) {
        GRAPH_DEBUG(
            "setting input ", i, " to type ", *n->ty(attr::profiled_type));
        dni->setType(n->ty(attr::profiled_type));
      } else if (dni->node()->kind() == prim::DifferentiableGraph) {
        // The profiling node might have been absorbed into a preceding
        // differentiable graph and thus is not visible here (which is not
        // ideal for fusing either); see
        // TestAutodiffSubgraphSlicing.test_does_not_create_cycles.
        // Alternatives to this special casing could be specializing the types
        // before autodiff or duplicating profile nodes for autodiff outputs,
        // but that should be done while creating subgraphs and would be
        // a mess.
        // XXX TODO: revisit the alternatives
        Value* o = dni->node()->g(attr::Subgraph)->outputs().at(dni->offset());
        if (o->node()->kind() == prim::profile) {
          dni->setType(o->node()->ty(attr::profiled_type));
        }
      }

      // Propagate the requires_grad property to the inputs.
      // A RequiresGrad check gets added (insertTypeGuard, below),
      // so requires_grad is guaranteed to match for the inputs;
      // other properties are not guaranteed to match.
      auto requires_grad = dni->type()->expectRef<TensorType>().requiresGrad();
      gi[i]->setType(ty->withRequiresGrad(requires_grad));

      // we check if the optional is defined
      all_inputs_seen &= (dni->type()->cast<TensorType>() != TensorType::get());
    }
  }
306     // we may have seen both true and false for requires_grad. In this case
307     // we guard with true here and the other case is in the fallback. This
308     // will give us trouble when we get "alternating patterns" of gradients
309     // of two inputs, but so it is. An alternative could be to look into
310     // the individual requires_grad seen in the profiling record.
311     insertTypeGuard(
312         dnode,
313         [](const TensorTypePtr& t) {
314           return TensorType::get()->withRequiresGrad(
315               t->requiresGrad().value_or(true));
316         },
317         prim::RequiresGradCheck);
318     return true;
319   } else {
320     // we inline the differentiable graph as a fallback
321     // ideally we would set this up for re-profiling
322     UpdateDifferentiableGraphRequiresGrad(
323         dnode->g(attr::Subgraph), std::nullopt);
324     SubgraphUtils::unmergeSubgraph(dnode);
325     return false;
326   }
327 }
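
// Rough sketch (hypothetical IR, details may differ): after a successful
// guard, the differentiable graph is protected by a check node roughly like
//
//   %t0, %t1, %ok : Tensor, Tensor, bool = prim::RequiresGradCheck(%a, %b)
//   ... = prim::If(%ok)
//     block0(): ... prim::DifferentiableGraph(%t0, %t1) ...
//     block1(): ... fallback path using the unspecialized inputs ...
//
// so the optimized gradient graph only runs when requires_grad matches what
// was profiled.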

void runNooptPassPipeline(std::shared_ptr<Graph>& graph) {
  GRAPH_DEBUG("Before Inliner (beginning of runNooptPassPipeline)\n", *graph);
  Inline(*graph);
  GRAPH_DEBUG("After Inline, Before NoGrad\n", *graph);
  LowerGradOf(*graph);
  GRAPH_DEBUG("After LowerGradOf, before RemoveExpands\n", *graph);
  RemoveExpands(graph);
  GRAPH_DEBUG("After RemoveExpands, before CanonicalizeOps\n", *graph);
  CanonicalizeOps(graph);
  GRAPH_DEBUG("After CanonicalizeOps, before EliminateDeadCode\n", *graph);
  EliminateDeadCode(graph);
  GRAPH_DEBUG(
      "After EliminateDeadCode (end of runNooptPassPipeline)\n", *graph);
}

static void runPreAutodiffPassPipeline(std::shared_ptr<Graph>& graph) {
  GRAPH_DEBUG(
      "Before InsertGuards (beginning of runPreAutodiffPassPipeline)\n",
      *graph);

  LowerGradOf(*graph);
  GRAPH_DEBUG("After LowerGradOf, before specializeAutogradZero\n", *graph);

  specializeAutogradZero(graph);
  GRAPH_DEBUG("After specializeAutogradZero\n", *graph);
  // runRequiredPasses
  {
    RemoveExpands(graph);
    GRAPH_DEBUG("After RemoveExpands, before CanonicalizeOps\n", *graph);
    CanonicalizeOps(graph);
    GRAPH_DEBUG("After CanonicalizeOps, before EliminateDeadCode\n", *graph);
    EliminateDeadCode(graph);
    GRAPH_DEBUG("After EliminateDeadCode", *graph);
  }
  PeepholeOptimize(graph);
  GRAPH_DEBUG("After PeepholeOptimize, before ConstantPropagation\n", *graph);
  ConstantPropagation(graph);

  // runOptimization:
  {
    EliminateDeadCode(graph);
    GRAPH_DEBUG(
        "After EliminateDeadCode, before EliminateCommonSubexpression\n",
        *graph);
    EliminateCommonSubexpression(graph);
    GRAPH_DEBUG(
        "After EliminateCommonSubexpression, before PeepholeOptimize\n",
        *graph);

    PeepholeOptimize(graph);
    GRAPH_DEBUG("After PeepholeOptimize, before ConstantPropagation\n", *graph);
    ConstantPropagation(graph);
    GRAPH_DEBUG("After ConstantPropagation, before ConstantPooling\n", *graph);
    ConstantPooling(graph);
    GRAPH_DEBUG("After ConstantPooling, before UnrollLoops\n", *graph);

    UnrollLoops(graph);
    GRAPH_DEBUG("After UnrollLoops, before RemoveListMutation\n", *graph);
    // run again with unrolled loops
    RemoveListMutation(graph);
    GRAPH_DEBUG("After RemoveListMutation, before PeepholeOptimize\n", *graph);
    PeepholeOptimize(graph);
    GRAPH_DEBUG("After PeepholeOptimize, before ConstantPropagation\n", *graph);
    ConstantPropagation(graph);
    GRAPH_DEBUG(
        "After ConstantPropagation, before EliminateCommonSubexpression\n",
        *graph);

    EliminateCommonSubexpression(graph);
    GRAPH_DEBUG(
        "After EliminateCommonSubexpression, before CheckInplace\n", *graph);
    CheckInplace(graph);
  }
  GRAPH_DEBUG(
      "After CheckInplace (end of runPreAutodiffPassPipeline)\n", *graph);
}

FusionBehavior ProfilingGraphExecutorImpl::getCurrentBehavior(
    size_t remaining_depth) {
  size_t curr_depth = 0;
  for (int i = static_cast<int>(fusion_strategy_.size()) - 1; i >= 0; i--) {
    curr_depth += fusion_strategy_[i].second;
    if (remaining_depth <= curr_depth) {
      return fusion_strategy_[i].first;
    }
  }
  // should never get here
  TORCH_WARN("Strategy changed mid-invocation, NYI");
  return FusionBehavior::STATIC;
}
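
// Worked example: with fusion_strategy_ = {{STATIC, 2}, {DYNAMIC, 10}} the
// total depth is 12. Walking from the back of the strategy, remaining_depth
// values 1..10 map to DYNAMIC and 11..12 map to STATIC, so the first two
// specializations use static-shape fusion and later ones use dynamic-shape
// fusion.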

void ProfilingGraphExecutorImpl::runNoGradOptimizations(
    std::shared_ptr<Graph>& graph,
    size_t remaining_bailout_depth) {
  GRAPH_DEBUG(
      "After customPostPasses (beginning of runNoGradOptimizations)\n", *graph);
  // runNondiffOptimization
  {
    // Run custom passes that different backends can register.
    for (const auto& passPair : getCustomPrePasses()) {
      passPair.first(graph);
    }
    GRAPH_DEBUG("After customPrePasses, before LowerSimpleTuples\n", *graph);

    // TupleConstruct / TupleUnpack pairs can still be present at this point
    // and must be removed for fusion.
    LowerSimpleTuples(graph);
    GRAPH_DEBUG("After LowerSimpleTuples\n", *graph);

    if (tensorExprFuserEnabled()) {
      // Remove prim::profile nodes and embed the profile info directly in the
      // IR via value types. We do this transformation because optimizations
      // that try to merge/fuse nodes in the graph (e.g. BatchMM and GraphFuser)
      // work worse in the presence of intermittent prim::profile nodes.
      // Optimizations relying on the type info are also responsible for
      // inserting proper type checks. Once we're done with these optimizations
      // we will wipe the tensor type information from the IR, so that it's not
      // accidentally used by any other pass.
      RemoveProfileNodesAndSpecializeTypes(graph);
      GRAPH_DEBUG(
          "After RemoveProfileNodesAndSpecializeTypes, before BatchMM\n",
          *graph);
      // Rewrite subgraphs with many MMs into expressions that batch them.
      BatchMM(graph);
      GRAPH_DEBUG("After BatchMM, before Fusion\n", *graph);
      auto min_size = getFusionGroupInlining() ? 2 : 1;
      bool dyn_shapes = getCurrentBehavior(remaining_bailout_depth) ==
          FusionBehavior::DYNAMIC;
      FuseTensorExprs(graph, min_size, /* composed op */ false, dyn_shapes);
      GRAPH_DEBUG("After Fusion, before customPostPasses\n", *graph);
    } else {
      // Rewrite subgraphs with many MMs into expressions that batch them.
      BatchMM(graph);
      GRAPH_DEBUG("After BatchMM, before Fusion\n", *graph);

      FuseGraph(graph, true);
      GRAPH_DEBUG("After Fusion, before customPostPasses\n", *graph);
    }

    // Run custom post-fusion passes.
    for (const auto& passPair : getCustomPostPasses()) {
      passPair.first(graph);
    }
    GRAPH_DEBUG(
        "After customPostPasses, before RemoveTensorTypeSpecializations\n",
        *graph);
    RemoveTensorTypeSpecializations(graph);
    GRAPH_DEBUG("After RemoveTensorTypeSpecializations\n", *graph);
  }
  GRAPH_DEBUG("End of runNoGradOptimizations\n");
}

void ProfilingGraphExecutorImpl::runProfilingOptimizations(
    std::shared_ptr<Graph>& copy,
    size_t remaining_bailout_depth) {
  GRAPH_DEBUG("Before runProfilingOptimizations:\n", *copy);
  if (!getGraphExecutorOptimize()) {
    runNooptPassPipeline(copy);
    return;
  }

  runPreAutodiffPassPipeline(copy);

  if (needsGradientInProfilingMode(copy->block())) {
    auto diff_nodes = CreateAutodiffSubgraphs(
        copy,
        getAutodiffSubgraphInlining() ? autodiffSubgraphNodeThreshold : 1);
    GRAPH_DEBUG("After CreateAutodiffSubgraphs\n", *copy);
    size_t idx = 0;
    for (Node* dnode : diff_nodes) {
      GRAPH_DEBUG("Optimizing diff node ", idx, " in ", *copy);
      if (!guardDifferentiableGraph(dnode)) {
        // if we cannot guard (because of inputs without profiling information),
        // we re-inline the subgraph and remove the differentiable node
        GRAPH_DEBUG("Could not guardDifferentiableGraph ", idx, " in ", *copy);
        idx++;
        continue;
      }
508       auto diff_graph = std::move(dnode->g(attr::Subgraph));
509       Gradient gradient = differentiate(diff_graph);
510       RemoveTensorTypeSpecializations(gradient.f);
511       ProfilingRecord::removeProfilingNodes(gradient.f->block());
512       GRAPH_DEBUG("Forward graph:\n", *(gradient.f));
513       GRAPH_DEBUG("Backward graph:\n", *(gradient.df));
514       // just like inside autograd.Functions, the forward of a differentiable
515       // graph is essentially in a torch.no_grad context.
516       UpdateDifferentiableGraphRequiresGrad(gradient.f, false);
517       GRAPH_DEBUG("After UpdateDifferentiableGraphRequiresGrad ", *gradient.f);
518       // replaces fallback graphs inserted by TE Fuser
519       replaceFallbackGraphWithFallbackFunction(gradient.f->block());
520       packGradient(gradient, dnode);
521       GRAPH_DEBUG("Finished optimizing diff node ", idx++);
522     }
523     InlineAutodiffSubgraphs(
524         copy,
525         getAutodiffSubgraphInlining() ? autodiffSubgraphNodeThreshold : 1);
526     replaceFallbackGraphWithFallbackFunction(copy->block());
527     ProfilingRecord::removeProfilingNodes(copy->block());
528     GRAPH_DEBUG(
529         "After InlineAutodiffSubgraphs and Removing Profiling Nodes\n", *copy);
530   } else {
531     runNoGradOptimizations(copy, remaining_bailout_depth);
532   }
533   EliminateDeadCode(copy);
534   GRAPH_DEBUG("After runProfilingOptimizations:\n", *copy);
535 }

void ProfilingGraphExecutorImpl::runProfilingInsensitiveOptimizations(
    std::shared_ptr<Graph>& graph) {
  GRAPH_DEBUG(
      "Before inlining (beginning of runProfilingInsensitiveOptimizations)\n",
      *graph);
  // TODO: maybe this can go later in the pipeline / directly in autodiff
  // forward creation
  if (getGraphExecutorOptimize()) {
    Inline(*graph);
  }
  GRAPH_DEBUG("After inlining, before ClearProfilingInformation\n", *graph);
  ClearProfilingInformation(graph);
  GRAPH_DEBUG("After ClearProfilingInformation, before LowerGradOf\n", *graph);
  LowerGradOf(*graph);
  GRAPH_DEBUG("After LowerGradOf, before ClearUndefinedness\n", *graph);
  // Clear any residual undefinedness, as double-backward graph inputs
  // may carry over undefinedness from profiled backward graphs.
  ClearUndefinedness(graph);
  // runRequiredPasses
  {
    GRAPH_DEBUG("After ClearUndefinedness, before RemoveExpands\n", *graph);
    RemoveExpands(graph);
    GRAPH_DEBUG("After RemoveExpands, before CanonicalizeOps\n", *graph);
    CanonicalizeOps(graph);
    GRAPH_DEBUG("After CanonicalizeOps, before EliminateDeadCode\n", *graph);
    EliminateDeadCode(graph);
  }
  if (!getGraphExecutorOptimize()) {
    GRAPH_DEBUG(
        "After EliminateDeadCode (end of runProfilingInsensitiveOptimizations)\n",
        *graph);
    return;
  }

  GRAPH_DEBUG("After EliminateDeadCode, before DecomposeOps\n", *graph);
  DecomposeOps(graph);
  GRAPH_DEBUG("After DecomposeOps, before ConstantPropagation\n", *graph);
  ConstantPropagation(graph);
  GRAPH_DEBUG("After ConstantPropagation, before EliminateDeadCode\n", *graph);
  EliminateDeadCode(graph);
  GRAPH_DEBUG(
      "After EliminateDeadCode, before EliminateCommonSubexpression\n", *graph);
  EliminateCommonSubexpression(graph);
  GRAPH_DEBUG(
      "After EliminateCommonSubexpression, before ConstantPooling\n", *graph);
  ConstantPooling(graph);
  GRAPH_DEBUG("After ConstantPooling, before PeepholeOptimize\n", *graph);
  PeepholeOptimize(graph);
  GRAPH_DEBUG("After PeepholeOptimize, before EliminateDeadCode\n", *graph);
  EliminateDeadCode(graph);
  GRAPH_DEBUG("After EliminateDeadCode, before LowerSimpleTuples\n", *graph);
  LowerSimpleTuples(graph);
  GRAPH_DEBUG("After LowerSimpleTuples, before CheckInplace\n", *graph);
  CheckInplace(graph);
  GRAPH_DEBUG(
      "After CheckInplace (end of runProfilingInsensitiveOptimizations)\n",
      *graph);
}

ProfilingGraphExecutorImpl::ProfilingGraphExecutorImpl(
    const std::shared_ptr<Graph>& graph,
    std::string function_name)
    : GraphExecutorImplBase(graph, std::move(function_name)) {
  fusion_strategy_ = getFusionStrategy();
}

size_t ProfilingGraphExecutorImpl::getInstantiatedBailoutDepth() {
  // The instantiated bailout depth is the total number of specializations
  // allowed by the fusion strategy captured at construction time.
  size_t depth = 0;
  for (const auto& pair : fusion_strategy_) {
    depth += pair.second;
  }
  return depth;
}

const ExecutionPlan& ProfilingGraphExecutorImpl::getOptimizedPlanFor(
    Stack& stack,
    std::optional<size_t> remaining_bailout_depth) {
  GRAPH_DEBUG("Running ProfilingGraphExecutorImpl ", this);

  // TODO: instantiate a simple executor when getProfilingMode() is false
  // no-opt mode
  if (!getGraphExecutorOptimize() || !getProfilingMode()) {
    if (!fallback_plan_) {
      auto copy = graph->copy();
      GRAPH_DEBUG(
          "Before LowerGradOf (beginning of runNooptPassPipeline)\n", *graph);
      LowerGradOf(*copy);
      GRAPH_DEBUG("After LowerGradOf, before RemoveExpands\n", *graph);
      RemoveExpands(copy);
      fallback_plan_ = ExecutionPlan(copy, function_name_);
      GRAPH_DUMP("NoOpt Graph: ", copy);
    }
    return *fallback_plan_;
  }

  // If tensorExprFuserEnabled() returns true, we need to persist the bailout
  // depth the very first time ProfilingGraphExecutorImpl is called, so we can
  // update it correctly for fallback functions in ProfilingGraphExecutorImpl.
  // Otherwise, getPlanFor(remaining_bailout_depth) is corrected and persisted
  // by the Code object in the interpreter.
  if (!remaining_bailout_depth_.has_value() || !tensorExprFuserEnabled()) {
    if (remaining_bailout_depth.has_value()) {
      remaining_bailout_depth_ = *remaining_bailout_depth;
    } else {
      remaining_bailout_depth_ = getInstantiatedBailoutDepth();
    }
  }

  // simple executor
  if (*remaining_bailout_depth_ == 0) {
    auto copy = graph->copy();
    runProfilingInsensitiveOptimizations(copy);
    GRAPH_DUMP("Optimized SimpleExecutor Graph: ", copy);
    optimized_plan_ = ExecutionPlan(copy, function_name_);
    time_optimized_plan_created_ = getNowInSecs();
    return *optimized_plan_;
  }

  bool profiling_record_created_in_this_call = false;
  // if a profiling graph hasn't been created yet
  if (!pr_) {
    auto copy = graph->copy();
    runProfilingInsensitiveOptimizations(copy);
    pr_ = ProfilingRecord::instrumentGraph(copy);
    profiling_record_created_in_this_call = true;
    // `InsertProfileNodesForSpecializeAutogradZero` profiles a definition vs a
    // use and it doesn't expect any profile nodes between a graph input and
    // its consumer, `aten::_grad_sum_to_size`. This means we need to run it
    // first, before any other pass that could insert `prim::iprofile_value`
    // nodes on an `aten::_grad_sum_to_size` input.
    InsertProfileNodesForSpecializeAutogradZero(pr_.get());
    GRAPH_DUMP("Profiled Graph: ", pr_->graph());
    profiling_plan_ = ExecutionPlan(pr_->graph(), function_name_);
    // fall-through
  }

  // profile until a graph is ready
  if (!pr_->ready()) {
    return *profiling_plan_;
  }

  auto copy = pr_->graph()->copy();
  ProfilingRecord::removeProfileCounter(copy->block());
  runProfilingOptimizations(copy, *remaining_bailout_depth_);
  // Replace a fallback graph inserted by
  // specialize_autogradzero if one exists.
  replaceFallbackGraphWithFallbackFunction(copy->block());
  runFinalOptimizations(copy);
  CheckStrictFusion(copy);
  GRAPH_DUMP("Optimized Graph: ", copy);
  optimized_plan_ = ExecutionPlan(copy, function_name_);
  time_optimized_plan_created_ = getNowInSecs();
  // If the profiling graph was created in this call, we can release it right
  // away.
  if (FLAGS_torch_jit_release_profiling_graph_after_optimization &&
      profiling_record_created_in_this_call) {
    clearTheGraphCompilationIntermediateGraphs();
  }
  return *optimized_plan_;
}

const ExecutionPlan& ProfilingGraphExecutorImpl::getPlanFor(
    Stack& stack,
    std::optional<size_t> remaining_bailout_depth) {
  std::lock_guard<std::mutex> lock(compile_mutex);

  // IMPORTANT: This is a hot path when calling a TorchScript function. Try not
  // to add any code above this.
  if (optimized_plan_) {
    if (FLAGS_torch_jit_release_profiling_graph_after_optimization &&
        !is_graph_extra_memory_released_) {
      int32_t now = getNowInSecs();
      if ((now - time_optimized_plan_created_) >
          FLAGS_torch_jit_release_profiling_graph_delay_in_seconds) {
        clearTheGraphCompilationIntermediateGraphs();
      }
    }
    return *optimized_plan_;
  }
  // If the depth is not set, getOptimizedPlanFor falls back to the
  // instantiated bailout depth.
  return getOptimizedPlanFor(stack, remaining_bailout_depth);
}

GraphExecutorState ProfilingGraphExecutorImpl::getDebugState() {
  GraphExecutorState state;
  TORCH_INTERNAL_ASSERT(optimized_plan_);
  auto opt_plan = *optimized_plan_;
  state.execution_plans.emplace(ArgumentSpec{0, 0}, opt_plan);
  return state;
}

static Node* insertFallbackFunctionCall(
    Graph* graph,
    GraphFunction* func,
    ArrayRef<Value*> inputs) {
  auto tuple_type = func->graph()->return_node()->input(0)->type();
  Value* fn_constant = graph->insertNode(graph->create(prim::Constant))
                           ->s_(attr::name, func->name())
                           ->i_(Symbol::attr("fallback"), 1)
                           ->output()
                           ->setType(FunctionType::create(func));
  std::vector<Value*> func_call_inputs = {fn_constant};
  func_call_inputs.insert(func_call_inputs.end(), inputs.begin(), inputs.end());
  Value* result =
      graph->insertNode(graph->create(prim::CallFunction, func_call_inputs))
          ->output()
          ->setType(tuple_type);

  auto fun_unpack_tuple = graph->insertNode(graph->createTupleUnpack(result));
  return fun_unpack_tuple;
}
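
// Shape of the emitted call (hypothetical IR, for illustration): the helper
// above produces roughly
//
//   %fn : Function = prim::Constant[name="fallback_function", fallback=1]()
//   %packed : (Tensor, Tensor) = prim::CallFunction(%fn, %x, %y)
//   %out0 : Tensor, %out1 : Tensor = prim::TupleUnpack(%packed)
//
// i.e. the fallback function returns a single tuple that is unpacked back
// into the original outputs.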

static GraphFunction* createFallbackPathFunction(
    Block* b,
    const std::string& function_name) {
  auto value_map = [](Value* v) { return v; };
  auto graph = std::make_shared<Graph>();
  graph->block()->cloneFrom(b, value_map);

  auto otypes = c10::fmap(
      graph->return_node()->inputs(), [](Value* v) { return v->type(); });
  // A GraphFunction call can only have one output, so all the outputs
  // need to be packed into a tuple.
  auto tuple_type = TupleType::create(otypes);
  auto return_tuple = graph->createTuple(graph->return_node()->inputs());
  graph->appendNode(return_tuple);
  for (int i = static_cast<int>(graph->outputs().size()) - 1; i >= 0; i--) {
    graph->eraseOutput(i);
  }
  graph->registerOutput(return_tuple->output());
  return new GraphFunction(function_name, graph, nullptr);
}

void ProfilingGraphExecutorImpl::replaceFallbackGraphWithFallbackFunction(
    Block* b) {
  Stack s;
  for (auto it = b->nodes().begin(); it != b->nodes().end();) {
    if (it->kind() == prim::FallbackGraph) {
      auto fallback_func = createFallbackPathFunction(
          it->g(attr::Subgraph)->block(), "fallback_function");
      TORCH_INTERNAL_ASSERT(*remaining_bailout_depth_ > 0);
      GRAPH_DEBUG(
          "getPlanFor for", getHeader(*it), " ", *remaining_bailout_depth_);
      fallback_func->get_executor().getPlanFor(
          s, *remaining_bailout_depth_ - 1);
      fallback_functions_.emplace_back(fallback_func);
      WithInsertPoint wip{*it};
      auto function_call = insertFallbackFunctionCall(
          b->owningGraph(), fallback_func, it->inputs());
      for (const auto i : c10::irange(function_call->outputs().size())) {
        it->output(i)->replaceAllUsesWith(function_call->output(i));
      }
      it.destroyCurrent();
    } else {
      for (Block* ib : it->blocks()) {
        replaceFallbackGraphWithFallbackFunction(ib);
      }
      it++;
    }
  }
}

void ProfilingGraphExecutorImpl::runFinalOptimizations(
    std::shared_ptr<Graph>& graph) {
  AddIfThenElseOp(graph);
}

void ProfilingGraphExecutorImpl::debugFlushCompilationCache() {
  std::lock_guard<std::mutex> lock(compile_mutex);
  pr_.reset();
  fallback_plan_.reset();
  profiling_plan_.reset();
  optimized_plan_.reset();
  // prevent memory leaks
  fallback_functions_.clear();
  remaining_bailout_depth_.reset();
  // TODO - would be nice to have it initialized in subsequent use
  fusion_strategy_ = getFusionStrategy();
  time_optimized_plan_created_ = 0;
  is_graph_extra_memory_released_ = false;
}

void ProfilingGraphExecutorImpl::clearTheGraphCompilationIntermediateGraphs() {
  is_graph_extra_memory_released_ = true;
  profiling_plan_.reset();
  fallback_plan_.reset();
  graph.reset();
  pr_.reset();
}

} // namespace torch::jit