#pragma once

#include <atomic>
#include <memory>

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/python/update_graph_executor_opt.h>
#include <torch/csrc/jit/runtime/argument_spec.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/runtime/variable_tensor_list.h>

C10_DECLARE_bool(torch_jit_enable_new_executor);

C10_DECLARE_bool(torch_jit_execution_plan_reuse_code_graph);

namespace torch::jit {
struct GraphExecutorState;
struct Code;

enum ExecutorExecutionMode {
  SIMPLE,
  PROFILING,
};

struct ExecutionPlan {
  ExecutionPlan() = default;
  ExecutionPlan(std::shared_ptr<Graph> graph, std::string function_name)
      : code(graph, std::move(function_name)),
        graph(
            FLAGS_torch_jit_execution_plan_reuse_code_graph
                ? code.graph()
                : std::move(graph)) {}

  operator bool() const {
    return static_cast<bool>(graph);
  }

  Code code;
  std::shared_ptr<Graph> graph;
};
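
// Constructing a plan compiles `graph` into `code`; depending on
// FLAGS_torch_jit_execution_plan_reuse_code_graph, the `graph` member then
// aliases `code.graph()` or keeps the caller's copy. A minimal sketch (the
// graph and function name are illustrative):
//
//   ExecutionPlan plan(graph, "forward");
//   if (plan) { // true once `graph` is populated
//     const Code& compiled = plan.code;
//   }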

// Notice that these structs don't manage the lifetime of their members.
// They are valid only right after you call getDebugState() and should
// never be used again once another GraphExecutor function is called.

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct GraphExecutorState {
  const Graph* graph = nullptr;
  ExecutionPlan fallback; // XXX: members of this field are optional
  std::unordered_map<ArgumentSpec, ExecutionPlan> execution_plans;
};
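
// A usage sketch honoring the lifetime caveat above (`executor` is a
// hypothetical GraphExecutor that has already been run):
//
//   GraphExecutorState state = executor.getDebugState();
//   size_t num_plans = state.execution_plans.size();
//   // `state` must not be used after any further call into `executor`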

struct TORCH_API EnableProfilingGuard {
  EnableProfilingGuard();
  ~EnableProfilingGuard();

 private:
  bool old_executor_mode = false;
  bool old_get_optimize = false;
};
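
// RAII usage sketch: the guard enables the profiling executor for its
// lifetime and restores the saved executor/optimize settings on
// destruction (behavior inferred from the saved members above):
//
//   {
//     EnableProfilingGuard guard;
//     // graphs executed here use the profiling executor
//   } // previous settings restored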

struct GraphExecutorImplBase;
struct TORCH_API GraphExecutor {
  GraphExecutor() = default;
  GraphExecutor(const std::shared_ptr<Graph>& graph, std::string function_name);

  GraphExecutor(
      const std::shared_ptr<Graph>& graph,
      std::string function_name,
      ExecutorExecutionMode executor_mode);

  void run(Stack& inputs);
  c10::intrusive_ptr<Future> runAsync(
      Stack& stack,
      TaskLauncher taskLauncher = at::launch);

  // `remaining_bailout_depth` is the maximum number of profiled and
  // specialized recompilations allowed for the current `GraphExecutor`. If
  // `remaining_bailout_depth` is equal to 0, `GraphExecutor` won't perform
  // any profiling or specialization; this is equivalent to
  // `ExecutorExecutionMode::SIMPLE`. If `remaining_bailout_depth` is greater
  // than 0, `GraphExecutor` will profile and specialize its input graph based
  // on the profiled information. Whenever a bailout check fails or is
  // triggered, a new `GraphExecutor` is created, with its
  // `remaining_bailout_depth` reduced by 1.
  // If no bailout depth is passed, the depth will be initialized from the
  // current global fusion strategy settings.
  const ExecutionPlan& getPlanFor(
      Stack& inputs,
      std::optional<size_t> remaining_bailout_depth = std::nullopt);
  GraphExecutorState getDebugState();

  void debugFlushCompilationCache();

  bool isOptimized() const;

 private:
  std::shared_ptr<GraphExecutorImplBase> pImpl;
};
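
// A minimal usage sketch, assuming `graph` is a valid shared_ptr<Graph>
// obtained elsewhere (e.g. from a scripted function); the input value is
// illustrative:
//
//   GraphExecutor executor(graph, "forward");
//   Stack stack;
//   stack.emplace_back(at::ones({2, 2}));
//   executor.run(stack); // outputs are left on the stack
//   IValue result = pop(stack);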

TORCH_API Node* replaceBlockWithFallbackGraph(
    Block* b,
    ArrayRef<Value*> inputs);

// These passes need to run before it is valid to pass the graph to the
// interpreter, regardless of whether sizes have been specialized or not.
TORCH_API void runRequiredPasses(const std::shared_ptr<Graph>& g);

TORCH_API void debugSetFusionGroupInlining(bool state);
TORCH_API bool getFusionGroupInlining();

TORCH_API void debugSetAutodiffSubgraphInlining(bool state);
TORCH_API std::shared_ptr<Graph> lastExecutedOptimizedGraph();

TORCH_API std::atomic<bool>& getProfilingMode();
TORCH_API std::atomic<bool>& getExecutorMode();
TORCH_API std::atomic<size_t>& getNumProfiledRuns();
TORCH_API size_t getBailoutDepth();
TORCH_API bool IsNewExecutorEnabled();

struct TORCH_API GraphOptimizerEnabledGuard {
  GraphOptimizerEnabledGuard(bool state)
      : old_state_(getGraphExecutorOptimize()) {
    setGraphExecutorOptimize(state);
  }

  ~GraphOptimizerEnabledGuard() {
    setGraphExecutorOptimize(old_state_);
  }

  bool old_state_;
};
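
// RAII usage sketch: temporarily disable graph-executor optimization and
// restore the previous setting when the guard leaves scope (`executor` and
// `stack` as in the earlier sketch):
//
//   {
//     GraphOptimizerEnabledGuard no_opt(/*state=*/false);
//     executor.run(stack); // runs without optimization
//   } // previous optimize flag restored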

namespace detail {

GraphExecutor* getGradExecutor(Operation& op);

GraphExecutor* getDifferentiableGraphOpExecutor(Operation& op);

// For debugging we expose a way to get the last graph that was actually
// run. Previous approaches allowed querying the GraphExecutor for the
// graph it would run in certain circumstances (graphFor), but this is
// fragile because we sometimes change how these decisions are made. This
// interface still allows our tests to look at optimized graphs, but with
// less plumbing.
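//
// A debugging sketch (`executor` and `stack` as in the earlier sketches;
// lastExecutedOptimizedGraph() lives in the enclosing torch::jit namespace):
//
//   executor.run(stack);
//   auto optimized = lastExecutedOptimizedGraph();
//   optimized->dump(); // print the graph that actually ran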
} // namespace detail

} // namespace torch::jit