#pragma once

#include <ATen/core/function.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/graph_executor.h>

namespace torch::jit {

struct TORCH_API GraphFunction : public Function {
  GraphFunction(
      c10::QualifiedName name,
      std::shared_ptr<Graph> graph,
      std::function<void(GraphFunction&)> function_creator,
      std::optional<ExecutorExecutionMode> executor_execution_mode =
          std::nullopt)
      : name_(std::move(name)),
        graph_(std::move(graph)),
        executor_execution_mode_(executor_execution_mode),
        function_creator_(std::move(function_creator)) {}

  bool isGraphFunction() const override {
    return true;
  }

  void run(Stack& stack) override;

  std::function<void(GraphFunction&)> function_creator() const {
    return function_creator_;
  }

  c10::intrusive_ptr<c10::ivalue::Future> runAsync(
      Stack& stack,
      TaskLauncher taskLauncher = at::launch) override;

  std::shared_ptr<Graph> graph() const {
    return graph_;
  }

  std::shared_ptr<Graph> optimized_graph() const;

  const c10::QualifiedName& qualname() const override {
    return name_;
  }

  // private/unstable api. sets the initial execution mode
  // will not affect executor if there is an existing executor
  // created for this function
  void _set_initial_executor_execution_mode(ExecutorExecutionMode mode) {
    executor_execution_mode_ = mode;
  }
  // private/unstable api. sets flag of whether or not to ignore amp.
  // will not affect executor if there is an existing executor
  // created for this function
  void _set_ignore_amp(bool ignore_amp) {
    force_no_amp_ = ignore_amp;
  }

  // if this isn't yet defined, run its method_creator function
  void ensure_defined() override;

  size_t num_inputs() const override {
    return graph()->inputs().size();
  }

  Function& setSchema(FunctionSchema schema) override {
    schema_ = std::make_unique<FunctionSchema>(std::move(schema));
    return *this;
  }

  const FunctionSchema& getSchema() const override;

  GraphExecutorState getDebugState() {
    return get_executor().getDebugState();
  }

  bool is_optimized() const {
    TORCH_WARN(
        "GraphFunction::is_optimized() is deprecated and always returns true. "
        "Please use getGraphExecutorOptimize()");
    return true;
  }

  void check_single_output() {
    TORCH_CHECK(
        graph()->outputs().size() == 1,
        "Methods (but not graphs in general) require a single output. "
        "Use None/Tuple for 0 or 2+ outputs");
  }

  GraphExecutor& get_executor() {
    ensure_defined();
    std::lock_guard<std::recursive_mutex> lock(compile_mutex);
    auto& executor = executors_[currentSpecialization()];
    if (executor) {
      return *executor;
    }
    check_single_output();
    const std::string& name = name_.name();
    std::shared_ptr<Graph> opt_graph = optimized_graph();
    if (!executor_execution_mode_) {
      executor = GraphExecutor(opt_graph, name);
    } else {
      executor = GraphExecutor(opt_graph, name, *executor_execution_mode_);
    }
    return *executor;
  }

  using Function::call;
  bool call(
      Stack& stack,
      std::optional<size_t> bailOut,
      c10::function_ref<void(const Code&)> f) override {
    f(get_executor().getPlanFor(stack, bailOut).code);
    return true;
  }

  void clear_optimized_graphs() {
    optimized_graphs_.fill(nullptr);
  }

 private:
  enum SpecializationKey {
    AutocastOff,
    CpuAutocastOn,
    GpuAutocastOn,
    CpuGpuAutocastOn,

    // This provides the number of specializations
    // (Must be last entry)
    TotalCount
  };

  SpecializationKey currentSpecialization() const;

 private:
  c10::QualifiedName name_;
  // The original, non-optimized graph
  std::shared_ptr<Graph> graph_; // for debugging and for inlining

  // allows users to specify Simple/Profiling Executor for function
  // TODO: add more executors
  mutable std::optional<ExecutorExecutionMode> executor_execution_mode_;

  // if invoked on a graph that has already traced through amp
  // don't invoke amp pass
  mutable bool force_no_amp_ = false;
  // Optimized graph, computed lazily. Used for inlining.
  mutable std::array<std::shared_ptr<Graph>, SpecializationKey::TotalCount>
      optimized_graphs_;

  // GraphFunctions are invokable from multiple threads, so this lock needs to
  // be held when we're initializing the graph executor for the first time or
  // computing the optimized graph. We're using a reentrant mutex so that we
  // don't need to worry about causing a deadlock by calling one method from
  // another (e.g. optimized_graph() from get_executor()).
  mutable std::recursive_mutex compile_mutex;

  // executors_[0] - autocast off
  // executors_[1] - autocast cpu on
  // executors_[2] - autocast gpu on
  // executors_[3] - autocast cpu & gpu on
  std::array<std::optional<GraphExecutor>, SpecializationKey::TotalCount>
      executors_;

  // an optional function that actually creates the method when
  // ensure_defined() is called. This is used by the compiler so
  // that it can construct methods out of order
  std::function<void(GraphFunction&)> function_creator_;

  // if absent, then we generate a default schema based on the graph
  // mutable because getSchema caches the default schema if one is requested
  // before a call to setSchema
  mutable std::unique_ptr<FunctionSchema> schema_;
};

// Shorthands for dynamic_cast<GraphFunction*>.
TORCH_API GraphFunction* tryToGraphFunction(Function&) noexcept;
TORCH_API GraphFunction& toGraphFunction(Function&);
TORCH_API const GraphFunction& toGraphFunction(const Function&);

} // namespace torch::jit
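// ---------------------------------------------------------------------------
// Usage sketch (illustrative only, not part of this header's API surface):
// constructing a GraphFunction around an existing Graph and running it,
// based solely on the declarations above. `buildAddGraph()` is a hypothetical
// helper that produces a Graph with a single output; inputs are pushed onto
// the Stack and outputs are left on it by run().
//
//   using namespace torch::jit;
//
//   std::shared_ptr<Graph> g = buildAddGraph();        // hypothetical helper
//   GraphFunction fn(
//       c10::QualifiedName("mymodule.forward"),
//       g,
//       /*function_creator=*/nullptr);                 // graph already defined
//
//   Stack stack;
//   stack.emplace_back(at::ones({2, 2}));              // push inputs
//   stack.emplace_back(at::ones({2, 2}));
//   fn.run(stack);                                     // result left on stack
// ---------------------------------------------------------------------------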