#pragma once

#include <atomic>
#include <memory>

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/python/update_graph_executor_opt.h>
#include <torch/csrc/jit/runtime/argument_spec.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/runtime/variable_tensor_list.h>

C10_DECLARE_bool(torch_jit_enable_new_executor);

C10_DECLARE_bool(torch_jit_execution_plan_reuse_code_graph);

namespace torch::jit {
struct GraphExecutorState;
struct Code;

enum ExecutorExecutionMode {
  SIMPLE,
  PROFILING,
};

struct ExecutionPlan {
  ExecutionPlan() = default;
  ExecutionPlan(std::shared_ptr<Graph> graph, std::string function_name)
      : code(graph, std::move(function_name)),
        graph(
            FLAGS_torch_jit_execution_plan_reuse_code_graph
                ? code.graph()
                : std::move(graph)) {}

  operator bool() const {
    return static_cast<bool>(graph);
  }

  Code code;
  std::shared_ptr<Graph> graph;
};

// Notice that these structs don't manage the lifetime of their members.
// They are only valid right after you call getDebugState() and should
// never be used again once another GraphExecutor function is called.

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct GraphExecutorState {
  const Graph* graph = nullptr;
  ExecutionPlan fallback; // XXX: members of this field are optional
  std::unordered_map<ArgumentSpec, ExecutionPlan> execution_plans;
};

struct TORCH_API EnableProfilingGuard {
  EnableProfilingGuard();
  ~EnableProfilingGuard();

 private:
  bool old_executor_mode = false;
  bool old_get_optimize = false;
};

struct GraphExecutorImplBase;
struct TORCH_API GraphExecutor {
  GraphExecutor() = default;
  GraphExecutor(const std::shared_ptr<Graph>& graph, std::string function_name);

  GraphExecutor(
      const std::shared_ptr<Graph>& graph,
      std::string function_name,
      ExecutorExecutionMode executor_mode);

  void run(Stack& inputs);
  c10::intrusive_ptr<Future> runAsync(
      Stack& stack,
      TaskLauncher taskLauncher = at::launch);

  // `remaining_bailout_depth` is the maximum number of profiled and
  // specialized recompilations allowed for the current `GraphExecutor`. If
  // `remaining_bailout_depth` is 0, the `GraphExecutor` won't perform any
  // profiling or specialization, which is equivalent to SIMPLE executor
  // mode. If `remaining_bailout_depth` is greater than 0, the
  // `GraphExecutor` will profile and specialize its input graph based on
  // the profiled information. Whenever a bailout check fails, a new
  // `GraphExecutor` is created with its `remaining_bailout_depth` reduced
  // by 1. If no bailout depth is passed, the depth is initialized from the
  // current global fusion strategy settings.
  const ExecutionPlan& getPlanFor(
      Stack& inputs,
      std::optional<size_t> remaining_bailout_depth = std::nullopt);
  GraphExecutorState getDebugState();

  void debugFlushCompilationCache();

  bool isOptimized() const;

 private:
  std::shared_ptr<GraphExecutorImplBase> pImpl;
};

TORCH_API Node* replaceBlockWithFallbackGraph(
    Block* b,
    ArrayRef<Value*> inputs);
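// A minimal usage sketch of `GraphExecutor` (illustrative only; `my_graph`
// is a hypothetical graph obtained elsewhere, e.g. from script
// compilation). `run` pops its inputs off the stack and pushes the
// outputs back onto it:
//
//   std::shared_ptr<Graph> my_graph = ...; // produced by the compiler
//   GraphExecutor executor(my_graph, "forward");
//   Stack stack;
//   stack.emplace_back(at::ones({2, 2})); // push the input tensor
//   executor.run(stack);                  // outputs are now on the stack
//   at::Tensor result = stack.back().toTensor();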
// These passes need to run before a graph is valid to pass to the
// interpreter, regardless of whether sizes have been specialized or not.
TORCH_API void runRequiredPasses(const std::shared_ptr<Graph>& g);

TORCH_API void debugSetFusionGroupInlining(bool state);
TORCH_API bool getFusionGroupInlining();

TORCH_API void debugSetAutodiffSubgraphInlining(bool state);
TORCH_API std::shared_ptr<Graph> lastExecutedOptimizedGraph();

TORCH_API std::atomic<bool>& getProfilingMode();
TORCH_API std::atomic<bool>& getExecutorMode();
TORCH_API std::atomic<size_t>& getNumProfiledRuns();
TORCH_API size_t getBailoutDepth();
TORCH_API bool IsNewExecutorEnabled();

struct TORCH_API GraphOptimizerEnabledGuard {
  GraphOptimizerEnabledGuard(bool state)
      : old_state_(getGraphExecutorOptimize()) {
    setGraphExecutorOptimize(state);
  }

  ~GraphOptimizerEnabledGuard() {
    setGraphExecutorOptimize(old_state_);
  }

  bool old_state_;
};

namespace detail {

GraphExecutor* getGradExecutor(Operation& op);

GraphExecutor* getDifferentiableGraphOpExecutor(Operation& op);

// For debugging we expose a way to get the last graph that was actually
// run. Previous approaches allowed querying the GraphExecutor for the
// graph it would run in certain circumstances (graphFor), but this is
// fragile because we sometimes change how these decisions are made. This
// interface still allows our tests to look at optimized graphs, but with
// less plumbing.
} // namespace detail

} // namespace torch::jit
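// A minimal sketch of the RAII guards above (illustrative only): each guard
// saves the current global setting in its constructor and restores it in
// its destructor, so the override lasts exactly as long as the scope:
//
//   {
//     // Temporarily run with graph optimization disabled, e.g. when
//     // debugging an optimization pass.
//     torch::jit::GraphOptimizerEnabledGuard no_opt_guard(/*state=*/false);
//     executor.run(stack);
//   } // the previous optimization setting is restored here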