// xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/api/function_impl.h
// (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <ATen/core/function.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/graph_executor.h>

namespace torch::jit {

9 struct TORCH_API GraphFunction : public Function {
10   GraphFunction(
11       c10::QualifiedName name,
12       std::shared_ptr<Graph> graph,
13       std::function<void(GraphFunction&)> function_creator,
14       std::optional<ExecutorExecutionMode> executor_execution_mode =
15           std::nullopt)
name_GraphFunction16       : name_(std::move(name)),
17         graph_(std::move(graph)),
18         executor_execution_mode_(executor_execution_mode),
19         function_creator_(std::move(function_creator)) {}
20 
isGraphFunctionGraphFunction21   bool isGraphFunction() const override {
22     return true;
23   }
24 
25   void run(Stack& stack) override;
26 
function_creatorGraphFunction27   std::function<void(GraphFunction&)> function_creator() const {
28     return function_creator_;
29   }
30 
31   c10::intrusive_ptr<c10::ivalue::Future> runAsync(
32       Stack& stack,
33       TaskLauncher taskLauncher = at::launch) override;
34 
graphGraphFunction35   std::shared_ptr<Graph> graph() const {
36     return graph_;
37   }
38 
39   std::shared_ptr<Graph> optimized_graph() const;
40 
qualnameGraphFunction41   const c10::QualifiedName& qualname() const override {
42     return name_;
43   }
44 
45   // private/unstable api. sets the initial execution mode
46   // will not affect executor if there is an existing executor
47   // created for this function
_set_initial_executor_execution_modeGraphFunction48   void _set_initial_executor_execution_mode(ExecutorExecutionMode mode) {
49     executor_execution_mode_ = mode;
50   }
51   // private/unstable api. sets flag of whether or not to ignore amp.
52   // will not affect executor if there is an existing executor
53   // created for this function
_set_ignore_ampGraphFunction54   void _set_ignore_amp(bool ignore_amp) {
55     force_no_amp_ = ignore_amp;
56   }
57 
58   // if this isn't yet defined, run its method_creator function
59   void ensure_defined() override;
60 
num_inputsGraphFunction61   size_t num_inputs() const override {
62     return graph()->inputs().size();
63   }
64 
setSchemaGraphFunction65   Function& setSchema(FunctionSchema schema) override {
66     schema_ = std::make_unique<FunctionSchema>(std::move(schema));
67     return *this;
68   }
69 
70   const FunctionSchema& getSchema() const override;
71 
getDebugStateGraphFunction72   GraphExecutorState getDebugState() {
73     return get_executor().getDebugState();
74   }
75 
is_optimizedGraphFunction76   bool is_optimized() const {
77     TORCH_WARN(
78         "GraphFunction::is_optimized() is deprecated and always returns true. "
79         "Please use getGraphExecutorOptimize()");
80     return true;
81   }
82 
check_single_outputGraphFunction83   void check_single_output() {
84     TORCH_CHECK(
85         graph()->outputs().size() == 1,
86         "Method (but not graphs in general) require a single output. Use None/Tuple for 0 or 2+ outputs");
87   }
88 
get_executorGraphFunction89   GraphExecutor& get_executor() {
90     ensure_defined();
91     std::lock_guard<std::recursive_mutex> lock(compile_mutex);
92     auto& executor = executors_[currentSpecialization()];
93     if (executor) {
94       return *executor;
95     }
96     check_single_output();
97     const std::string& name = name_.name();
98     std::shared_ptr<Graph> opt_graph = optimized_graph();
99     if (!executor_execution_mode_) {
100       executor = GraphExecutor(opt_graph, name);
101     } else {
102       executor = GraphExecutor(opt_graph, name, *executor_execution_mode_);
103     }
104     return *executor;
105   }
106 
107   using Function::call;
callGraphFunction108   bool call(
109       Stack& stack,
110       std::optional<size_t> bailOut,
111       c10::function_ref<void(const Code&)> f) override {
112     f(get_executor().getPlanFor(stack, bailOut).code);
113     return true;
114   }
115 
clear_optimized_graphsGraphFunction116   void clear_optimized_graphs() {
117     optimized_graphs_.fill(nullptr);
118   }
119 
120  private:
121   enum SpecializationKey {
122     AutocastOff,
123     CpuAutocastOn,
124     GpuAutocastOn,
125     CpuGpuAutocastOn,
126 
127     // This provides the number of specializations
128     // (Must be last entry)
129     TotalCount
130   };
131 
132   SpecializationKey currentSpecialization() const;
133 
134  private:
135   c10::QualifiedName name_;
136   // The original, non-optimized graph
137   std::shared_ptr<Graph> graph_; // for debugging and for inlining
138 
139   // allows users to specify Simple/Profiling Executor for function
140   // TODO: add more executors
141   mutable std::optional<ExecutorExecutionMode> executor_execution_mode_;
142 
143   // if invoked on a graph that has already traced through amp
144   // don't invoke amp pass
145   mutable bool force_no_amp_ = false;
146   // Optimized graph, computed lazily. Used for inlining.
147   mutable std::array<std::shared_ptr<Graph>, SpecializationKey::TotalCount>
148       optimized_graphs_;
149 
150   // GraphFunctions are invokable from multiple threads, so this lock needs to
151   // be held when we're initializing graph executor for the first time or
152   // computing the optimized graph. We're using reentrant mutex so that we don't
153   // need to worry about causing a deadlock by calling one method from another
154   // (e.g. optimized_graph() from get_executor()).
155   mutable std::recursive_mutex compile_mutex;
156 
157   // executor_[0] - autocast off
158   // executor_[1] - autocast cpu on
159   // executor_[2] - autocast gpu on
160   // executor_[3] - autocast cpu & gpu on
161   std::array<std::optional<GraphExecutor>, SpecializationKey::TotalCount>
162       executors_;
163 
164   // an optional function that actually creates the method when
165   // ensure_defined() is called. This is used by the compiler so
166   // that it can construct methods out of order
167   std::function<void(GraphFunction&)> function_creator_;
168 
169   // if absent, then we generate a default schema based on the graph
170   // mutable because getSchema caches the default schema if one is requested
171   // before a call to setSchema
172   mutable std::unique_ptr<FunctionSchema> schema_;
173 };
// Short hands for dynamic_cast<GraphFunction*>.
// tryToGraphFunction: returns nullptr when the Function is not a
// GraphFunction (never throws).
// toGraphFunction: reference-returning variants; presumably these
// assert/throw on a non-GraphFunction input — confirm in function_impl.cpp.
TORCH_API GraphFunction* tryToGraphFunction(Function&) noexcept;
TORCH_API GraphFunction& toGraphFunction(Function&);
TORCH_API const GraphFunction& toGraphFunction(const Function&);

} // namespace torch::jit