// xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/runtime/profiling_graph_executor_impl.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <c10/util/Flags.h>
3 #include <torch/csrc/jit/api/module.h>
4 #include <torch/csrc/jit/runtime/graph_executor_impl.h>
5 
6 C10_DECLARE_bool(torch_jit_static_then_dynamic);
7 
8 C10_DECLARE_bool(torch_jit_always_dynamic);
9 
10 namespace torch::jit {
11 
12 TORCH_API void runNooptPassPipeline(std::shared_ptr<Graph>& graph);
13 
14 struct TORCH_API ProfilingGraphExecutorImpl : public GraphExecutorImplBase {
15   ProfilingGraphExecutorImpl(
16       const std::shared_ptr<Graph>& graph,
17       std::string function_name);
18 
19   const ExecutionPlan& getPlanFor(
20       Stack& stack,
21       std::optional<size_t> remaining_bailout_depth) override;
22   GraphExecutorState getDebugState() override;
23   ~ProfilingGraphExecutorImpl() override = default;
24 
25   void debugFlushCompilationCache();
26 
isOptimizedProfilingGraphExecutorImpl27   bool isOptimized() const override {
28     return optimized_plan_.has_value();
29   }
30 
31  private:
32   const ExecutionPlan& getOptimizedPlanFor(
33       Stack& stack,
34       std::optional<size_t> remaining_bailout_depth);
35   void runProfilingInsensitiveOptimizations(std::shared_ptr<Graph>& graph);
36   void runProfilingOptimizations(
37       std::shared_ptr<Graph>& graph,
38       size_t remaining_depth);
39   void replaceFallbackGraphWithFallbackFunction(Block* b);
40   FusionBehavior getCurrentBehavior(size_t remaining_depth);
41   size_t getInstantiatedBailoutDepth();
42   void runNoGradOptimizations(
43       std::shared_ptr<Graph>& graph,
44       size_t remaining_bailout_depth);
45   void runFinalOptimizations(std::shared_ptr<Graph>& graph);
46 
47   void clearTheGraphCompilationIntermediateGraphs();
48 
49   std::unique_ptr<ProfilingRecord> pr_;
50   std::optional<ExecutionPlan>
51       profiling_plan_; // plan to run in order to profiling the code
52   std::optional<ExecutionPlan> optimized_plan_;
53   FusionStrategy fusion_strategy_;
54 
55   // this plan is used if getGraphExecutorOptimize is unset
56   std::optional<ExecutionPlan> fallback_plan_;
57   // fallback functions are inserted for tensorexpr fusion groups
58   // and by specialize_autogradzero. Whenever, at runtime, input
59   // tensor don't match profiled properties, fallback functions are called
60   // They are the deoptimized version of the logic in fusion groups
61   // and/or autograd.
62   // The fallback functions are owned by a GraphExecutor instance
63   // They only exist in the optimized graph which is a private property
64   // of the GraphExecutor and only shared with InterpreterState
65   std::vector<std::unique_ptr<Function>> fallback_functions_;
66   std::optional<size_t> remaining_bailout_depth_;
67   // The time the optimized_plan_ is created.
68   int32_t time_optimized_plan_created_ = 0;
69   // Has the extra memory used by the graph for profiling is released?
70   bool is_graph_extra_memory_released_ = false;
71 };
72 
73 } // namespace torch::jit
74