#include <c10/util/Flags.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/passes/inliner.h>

#include <torch/csrc/jit/frontend/error_report.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/peephole.h>

#ifndef C10_MOBILE
#include <ATen/autocast_mode.h>
#include <torch/csrc/jit/passes/autocast.h>
#endif

C10_DEFINE_bool(
    torch_jit_do_not_store_optimized_graph,
    false,
    "Do not store the optimized graph.");

namespace torch::jit {
namespace {
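// Builds a fallback FunctionSchema for a GraphFunction with no explicit
// schema: one argument per graph input (named after the value's debug name
// when available, otherwise "argument_<i>") and one unnamed return per graph
// output, with shape information stripped from all types.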
c10::FunctionSchema defaultSchemaFor(const GraphFunction& function) {
  std::vector<c10::Argument> args;
  std::vector<c10::Argument> returns;
  Graph& g = *function.graph();
  size_t num_inputs = function.num_inputs();
  for (const auto i : c10::irange(num_inputs)) {
    const Value* v = g.inputs().at(i);
    std::string name = v->hasDebugName() ? v->debugNameBase()
                                         : ("argument_" + std::to_string(i));
    args.emplace_back(std::move(name), unshapedType(g.inputs()[i]->type()));
  }
  for (const auto i : c10::irange(g.outputs().size())) {
    returns.emplace_back("", unshapedType(g.outputs()[i]->type()));
  }
  return {function.name(), "", std::move(args), std::move(returns)};
}

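// Returns the function downcast to T (a possibly-const GraphFunction) when it
// is graph-backed, otherwise nullptr.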
template <typename T, typename F>
T* tryToGraphFunctionImpl(F& function) noexcept {
  if (!function.isGraphFunction()) {
    return nullptr;
  }

  return static_cast<T*>(&function);
}

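// Like tryToGraphFunctionImpl, but asserts instead of returning nullptr when
// the function is not a GraphFunction.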
template <typename T, typename F>
T& toGraphFunctionImpl(F& function) {
  if (auto* g = tryToGraphFunctionImpl<T>(function)) {
    return *g;
  }

  TORCH_INTERNAL_ASSERT(
      false,
      "Failed to downcast a Function to a GraphFunction. "
      "This probably indicates that the JIT calling context needs a "
      "special case on tryToGraphFunction() instead.");
}

} // namespace

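// Stand-in creator installed while a function is being defined; invoking it
// means the function was called recursively during its own definition.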
static void placeholderCreator(GraphFunction&) {
  throw RecursiveMethodCallError();
}

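// Runs the function synchronously on the given stack via its graph executor.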
void GraphFunction::run(Stack& stack) {
  C10_LOG_EVENT_SAMPLED(run, qualname().qualifiedName(), stack);
  get_executor().run(stack);
}

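// Runs the function asynchronously, scheduling work through the provided
// TaskLauncher and returning a future for the result.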
c10::intrusive_ptr<c10::ivalue::Future> GraphFunction::runAsync(
    Stack& stack,
    TaskLauncher taskLauncher) {
  return get_executor().runAsync(stack, std::move(taskLauncher));
}

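// Lazily runs the function creator, swapping in placeholderCreator while it
// executes so that a recursive call during definition throws
// RecursiveMethodCallError.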
void GraphFunction::ensure_defined() {
  if (function_creator_) {
    auto creator = function_creator_;
    function_creator_ = placeholderCreator;
    creator(*this);
    function_creator_ = nullptr;
  }
  check_single_output();
}

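// Returns the function schema, synthesizing a default one from the graph's
// inputs and outputs on first use.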
const c10::FunctionSchema& GraphFunction::getSchema() const {
  if (schema_ == nullptr) {
    schema_ = std::make_unique<c10::FunctionSchema>(defaultSchemaFor(*this));
  }
  return *schema_;
}

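// Returns a copy of the graph, preoptimized when graph executor optimization
// is enabled. Copies are cached per autocast specialization unless
// FLAGS_torch_jit_do_not_store_optimized_graph is set, in which case a fresh
// copy is produced on every call.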
std::shared_ptr<Graph> GraphFunction::optimized_graph() const {
  std::lock_guard<std::recursive_mutex> lock(compile_mutex);
  decltype(optimized_graphs_)::value_type graph;
  auto& graph_ref = !FLAGS_torch_jit_do_not_store_optimized_graph
      ? optimized_graphs_[currentSpecialization()]
      : graph;
  if (graph_ref) {
    return graph_ref;
  }
  graph_ref = graph_->copy();
  if (getGraphExecutorOptimize()) {
    preoptimizeGraph(graph_ref, force_no_amp_);
  }
  return graph_ref;
}

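// Selects the optimized-graph cache key from the currently enabled autocast
// modes (CPU and/or CUDA); mobile builds and force_no_amp_ functions always
// map to AutocastOff.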
GraphFunction::SpecializationKey GraphFunction::currentSpecialization() const {
  if (force_no_amp_) {
    return SpecializationKey::AutocastOff;
  }
#ifdef C10_MOBILE
  // disabling autocast pass for mobile build since autocast APIs don't exist
  return SpecializationKey::AutocastOff;
#else
  bool cpu_enabled = at::autocast::is_autocast_enabled(at::kCPU);
  bool gpu_enabled = at::autocast::is_autocast_enabled(at::kCUDA);
  if (cpu_enabled && gpu_enabled) {
    return SpecializationKey::CpuGpuAutocastOn;
  } else if (!cpu_enabled && !gpu_enabled) {
    return SpecializationKey::AutocastOff;
  } else {
    return gpu_enabled ? SpecializationKey::GpuAutocastOn
                       : SpecializationKey::CpuAutocastOn;
  }
#endif
}

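// Cleanup passes applied to a graph copy before it reaches the graph
// executor: inlining, peephole optimization, constant propagation restricted
// to immutable types, optional autocast insertion (non-mobile only), and
// constant pooling.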
void preoptimizeGraph(std::shared_ptr<Graph>& graph, bool disable_autocast) {
  Inline(*graph);

  // Peephole Optimize cleans up many "is None" checks and creates constant prop
  // opportunities
  PeepholeOptimize(graph, true);

  // AliasDb construction can be slow, so run it just on immutable types
  // to clean up constant Ifs & other easy wins
  ConstantPropagationImmutableTypes(graph);

#ifndef C10_MOBILE
  // Inject casts for automatic mixed precision
  //
  // TODO: Ideally, this pass could run earlier, before inlining
  //  or any other optimizations. That setup is preferable because:
  //  1. The AMP pass would be self-contained and function independently
  //     of any other optimizations
  //  2. AMP transformations would benefit from followup passes' cleanup
  //
  if (!disable_autocast) {
    Autocast(graph);
  }
#endif

  ConstantPooling(graph);
}

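// Public downcast helpers: tryToGraphFunction returns nullptr for non-graph
// functions, while the toGraphFunction overloads hard-assert on failure.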
GraphFunction* tryToGraphFunction(Function& function) noexcept {
  return tryToGraphFunctionImpl<GraphFunction>(function);
}

GraphFunction& toGraphFunction(Function& function) {
  return toGraphFunctionImpl<GraphFunction>(function);
}

const GraphFunction& toGraphFunction(const Function& function) {
  return toGraphFunctionImpl<const GraphFunction>(function);
}

} // namespace torch::jit