#include <c10/util/Flags.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/passes/inliner.h>

#include <torch/csrc/jit/frontend/error_report.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/peephole.h>

#ifndef C10_MOBILE
#include <ATen/autocast_mode.h>
#include <torch/csrc/jit/passes/autocast.h>
#endif

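// If true, optimized_graph() rebuilds the optimized graph on every call
// instead of caching one per autocast specialization.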
C10_DEFINE_bool(
    torch_jit_do_not_store_optimized_graph,
    false,
    "Do not store the optimized graph.");

namespace torch::jit {
namespace {
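// Constructs a fallback schema directly from the graph: one argument per
// graph input (using debug names where available) and one unnamed return
// per graph output, with shape information erased from all types.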
c10::FunctionSchema defaultSchemaFor(const GraphFunction& function) {
  std::vector<c10::Argument> args;
  std::vector<c10::Argument> returns;
  Graph& g = *function.graph();
  size_t num_inputs = function.num_inputs();
  for (const auto i : c10::irange(num_inputs)) {
    const Value* v = g.inputs().at(i);
    std::string name = v->hasDebugName() ? v->debugNameBase()
                                         : ("argument_" + std::to_string(i));
    args.emplace_back(std::move(name), unshapedType(g.inputs()[i]->type()));
  }
  for (const auto i : c10::irange(g.outputs().size())) {
    returns.emplace_back("", unshapedType(g.outputs()[i]->type()));
  }
  return {function.name(), "", std::move(args), std::move(returns)};
}

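// Returns the function as a T* if it is a GraphFunction, nullptr otherwise.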
template <typename T, typename F>
T* tryToGraphFunctionImpl(F& function) noexcept {
  if (!function.isGraphFunction()) {
    return nullptr;
  }

  return static_cast<T*>(&function);
}

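// Like tryToGraphFunctionImpl, but fails with an internal assert instead of
// returning nullptr.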
template <typename T, typename F>
T& toGraphFunctionImpl(F& function) {
  if (auto* g = tryToGraphFunctionImpl<T>(function)) {
    return *g;
  }

  TORCH_INTERNAL_ASSERT(
      false,
      "Failed to downcast a Function to a GraphFunction. "
      "This probably indicates that the JIT calling context needs a "
      "special case on tryToGraphFunction() instead.");
}

} // namespace

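// Installed as function_creator_ while a definition is in flight; invoking
// it means the method called itself recursively during its own definition.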
static void placeholderCreator(GraphFunction&) {
  throw RecursiveMethodCallError();
}

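// Executes the function synchronously via the graph executor, consuming
// arguments from and leaving results on the stack.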
void GraphFunction::run(Stack& stack) {
  C10_LOG_EVENT_SAMPLED(run, qualname().qualifiedName(), stack);
  get_executor().run(stack);
}

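// Asynchronous counterpart of run(); returns a Future that completes when
// execution finishes, scheduling work through the provided taskLauncher.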
c10::intrusive_ptr<c10::ivalue::Future> GraphFunction::runAsync(
    Stack& stack,
    TaskLauncher taskLauncher) {
  return get_executor().runAsync(stack, std::move(taskLauncher));
}

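// Materializes a lazily-defined function by invoking its creator exactly
// once; placeholderCreator is installed for the duration of the call so
// that a recursive definition raises RecursiveMethodCallError.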
void GraphFunction::ensure_defined() {
  if (function_creator_) {
    auto creator = function_creator_;
    function_creator_ = placeholderCreator;
    creator(*this);
    function_creator_ = nullptr;
  }
  check_single_output();
}

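// Lazily derives the schema from the graph on first request and caches it.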
const c10::FunctionSchema& GraphFunction::getSchema() const {
  if (schema_ == nullptr) {
    schema_ = std::make_unique<c10::FunctionSchema>(defaultSchemaFor(*this));
  }
  return *schema_;
}

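// Returns the preoptimized graph for the current autocast specialization,
// caching the result unless FLAGS_torch_jit_do_not_store_optimized_graph
// is set, in which case a fresh copy is optimized on every call.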
std::shared_ptr<Graph> GraphFunction::optimized_graph() const {
  std::lock_guard<std::recursive_mutex> lock(compile_mutex);
  decltype(optimized_graphs_)::value_type graph;
  auto& graph_ref = !FLAGS_torch_jit_do_not_store_optimized_graph
      ? optimized_graphs_[currentSpecialization()]
      : graph;
  if (graph_ref) {
    return graph_ref;
  }
  graph_ref = graph_->copy();
  if (getGraphExecutorOptimize()) {
    preoptimizeGraph(graph_ref, force_no_amp_);
  }
  return graph_ref;
}

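// Selects which optimized-graph cache slot applies, based on the current
// CPU/GPU autocast state.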
GraphFunction::SpecializationKey GraphFunction::currentSpecialization() const {
  if (force_no_amp_) {
    return SpecializationKey::AutocastOff;
  }
#ifdef C10_MOBILE
  // Autocast APIs don't exist in mobile builds, so no autocast
  // specialization is possible there.
  return SpecializationKey::AutocastOff;
#else
  bool cpu_enabled = at::autocast::is_autocast_enabled(at::kCPU);
  bool gpu_enabled = at::autocast::is_autocast_enabled(at::kCUDA);
  if (cpu_enabled && gpu_enabled) {
    return SpecializationKey::CpuGpuAutocastOn;
  } else if (!cpu_enabled && !gpu_enabled) {
    return SpecializationKey::AutocastOff;
  } else {
    return gpu_enabled ? SpecializationKey::GpuAutocastOn
                       : SpecializationKey::CpuAutocastOn;
  }
#endif
}

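// Applies the cheap cleanup passes that are always worthwhile on a freshly
// copied graph: inlining, peephole simplification, constant propagation
// over immutable types, optional autocast injection, and constant pooling.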
void preoptimizeGraph(std::shared_ptr<Graph>& graph, bool disable_autocast) {
  Inline(*graph);

  // Peephole optimization cleans up many "is None" checks and creates
  // constant propagation opportunities.
  PeepholeOptimize(graph, true);

  // AliasDb construction can be slow, so run constant propagation only on
  // immutable types to clean up constant Ifs and other easy wins.
  ConstantPropagationImmutableTypes(graph);

#ifndef C10_MOBILE
  // Inject casts for automatic mixed precision.
  //
  // TODO: Ideally, this pass could run earlier, before inlining
  // or any other optimizations. That setup is preferable because:
  // 1. The AMP pass would be self-contained and function independently
  //    of any other optimizations.
  // 2. AMP transformations would benefit from follow-up passes' cleanup.
  //
  if (!disable_autocast) {
    Autocast(graph);
  }
#endif

  ConstantPooling(graph);
}

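// Public wrappers around the downcast helpers above: tryToGraphFunction
// returns nullptr on failure, while the toGraphFunction overloads assert.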
GraphFunction* tryToGraphFunction(Function& function) noexcept {
  return tryToGraphFunctionImpl<GraphFunction>(function);
}

GraphFunction& toGraphFunction(Function& function) {
  return toGraphFunctionImpl<GraphFunction>(function);
}

const GraphFunction& toGraphFunction(const Function& function) {
  return toGraphFunctionImpl<const GraphFunction>(function);
}

} // namespace torch::jit