1 #pragma once 2 #include <c10/util/Flags.h> 3 #include <torch/csrc/jit/api/module.h> 4 #include <torch/csrc/jit/runtime/graph_executor_impl.h> 5 6 C10_DECLARE_bool(torch_jit_static_then_dynamic); 7 8 C10_DECLARE_bool(torch_jit_always_dynamic); 9 10 namespace torch::jit { 11 12 TORCH_API void runNooptPassPipeline(std::shared_ptr<Graph>& graph); 13 14 struct TORCH_API ProfilingGraphExecutorImpl : public GraphExecutorImplBase { 15 ProfilingGraphExecutorImpl( 16 const std::shared_ptr<Graph>& graph, 17 std::string function_name); 18 19 const ExecutionPlan& getPlanFor( 20 Stack& stack, 21 std::optional<size_t> remaining_bailout_depth) override; 22 GraphExecutorState getDebugState() override; 23 ~ProfilingGraphExecutorImpl() override = default; 24 25 void debugFlushCompilationCache(); 26 isOptimizedProfilingGraphExecutorImpl27 bool isOptimized() const override { 28 return optimized_plan_.has_value(); 29 } 30 31 private: 32 const ExecutionPlan& getOptimizedPlanFor( 33 Stack& stack, 34 std::optional<size_t> remaining_bailout_depth); 35 void runProfilingInsensitiveOptimizations(std::shared_ptr<Graph>& graph); 36 void runProfilingOptimizations( 37 std::shared_ptr<Graph>& graph, 38 size_t remaining_depth); 39 void replaceFallbackGraphWithFallbackFunction(Block* b); 40 FusionBehavior getCurrentBehavior(size_t remaining_depth); 41 size_t getInstantiatedBailoutDepth(); 42 void runNoGradOptimizations( 43 std::shared_ptr<Graph>& graph, 44 size_t remaining_bailout_depth); 45 void runFinalOptimizations(std::shared_ptr<Graph>& graph); 46 47 void clearTheGraphCompilationIntermediateGraphs(); 48 49 std::unique_ptr<ProfilingRecord> pr_; 50 std::optional<ExecutionPlan> 51 profiling_plan_; // plan to run in order to profiling the code 52 std::optional<ExecutionPlan> optimized_plan_; 53 FusionStrategy fusion_strategy_; 54 55 // this plan is used if getGraphExecutorOptimize is unset 56 std::optional<ExecutionPlan> fallback_plan_; 57 // fallback functions are inserted for tensorexpr fusion groups 58 // and by specialize_autogradzero. Whenever, at runtime, input 59 // tensor don't match profiled properties, fallback functions are called 60 // They are the deoptimized version of the logic in fusion groups 61 // and/or autograd. 62 // The fallback functions are owned by a GraphExecutor instance 63 // They only exist in the optimized graph which is a private property 64 // of the GraphExecutor and only shared with InterpreterState 65 std::vector<std::unique_ptr<Function>> fallback_functions_; 66 std::optional<size_t> remaining_bailout_depth_; 67 // The time the optimized_plan_ is created. 68 int32_t time_optimized_plan_created_ = 0; 69 // Has the extra memory used by the graph for profiling is released? 70 bool is_graph_extra_memory_released_ = false; 71 }; 72 73 } // namespace torch::jit 74