#include <torch/csrc/jit/passes/create_functional_graphs.h>

#include <c10/util/Exception.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>

#include <cstddef>
#include <limits>

namespace torch::jit {

namespace {

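// Slices a graph into prim::FunctionalGraph subgraphs: groups of nodes that
// do not mutate state and whose outputs are neither written to nor escape
// the graph's scope. Subgraphs that end up with fewer than minSubgraphSize_
// non-constant nodes are inlined back into their owning block.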
struct FunctionalGraphSlicer {
  FunctionalGraphSlicer(std::shared_ptr<Graph> graph)
      : graph_(std::move(graph)) {}

  void run() {
    bool changed = true;
    // TODO: more sane strategy
    constexpr size_t MAX_NUM_ITERATIONS = 4;

    // First, analyze the functional subset of the graph, and then create
    // functional graphs. The graph gets mutated when we create functional
    // subgraphs, invalidating the AliasDb, so we need to do our analysis
    // first.
    for (size_t i = 0; i < MAX_NUM_ITERATIONS && changed; ++i) {
      aliasDb_ = std::make_unique<AliasDb>(graph_);
      AnalyzeFunctionalSubset(graph_->block());
      changed = CreateFunctionalGraphsImpl(graph_->block());
    }
  }

 private:
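  // A newly created functional subgraph has no inputs or outputs until a
  // node has been merged into it.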
  bool isEmptyFunctionalGraph(Node* n) {
    auto g = n->g(attr::Subgraph);
    return g->inputs().empty() && g->outputs().empty();
  }

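  // Counts the non-constant nodes in `block`, recursing into sub-blocks, and
  // accumulates the result into `*num`. Stops early once the count reaches
  // minSubgraphSize_, since callers only care whether that threshold is met.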
  void nonConstNodes(Block* block, size_t* num) {
    for (auto it = block->nodes().begin();
         it != block->nodes().end() && *num < minSubgraphSize_;
         ++it) {
      Node* n = *it;
      if (n->kind() == prim::Constant) {
        continue;
      }
      *num = *num + 1;
      for (Block* b : n->blocks()) {
        nonConstNodes(b, num);
      }
    }
  }

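  // Unmerges the subgraph held by `n` back into the surrounding block if it
  // contains fewer than minSubgraphSize_ non-constant nodes. Returns true if
  // the subgraph was inlined.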
  bool inlineIfTooSmall(Node* n) {
    AT_ASSERT(n->kind() == prim::FunctionalGraph);
    auto subgraph = SubgraphUtils::getSubgraph(n);
    size_t num_nodes = 0;
    nonConstNodes(subgraph->block(), &num_nodes);
    if (num_nodes < minSubgraphSize_) {
      SubgraphUtils::unmergeSubgraph(n);
      return true;
    }
    return false;
  }

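  // Returns true if any node was merged into a functional subgraph during
  // this call.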
  bool CreateFunctionalGraphsImpl(Block* block) {
    /*
    Iterate the block in reverse and create FunctionalSubgraphs.
    When we encounter a node that isn't functional, we skip it. Otherwise,
    we try to merge the functional node into the current functional subgraph.
    If it can't be merged into the current functional subgraph node, then we
    start a new functional subgraph at that node.
    */
    bool changed = false;
    std::vector<Node*> functional_graph_nodes;

    Node* functional_subgraph_node =
        graph_->createWithSubgraph(prim::FunctionalGraph)
            ->insertBefore(block->return_node());
    auto reverse_iter = block->nodes().reverse();
    std::vector<Value*> graph_outputs;
    for (auto it = reverse_iter.begin(); it != reverse_iter.end();) {
      Node* n = *it++;

      // constants get copied into the graph
      if (n->kind() == prim::Constant || n == functional_subgraph_node) {
        continue;
      }

      // if `n` is functional, all of its blocks will be merged into the
      // new functional subgraph, so we only need to recurse if it is not
      // functional
      if (!functional_nodes_.count(n)) {
        for (Block* b : n->blocks()) {
          auto block_changed = CreateFunctionalGraphsImpl(b);
          changed = block_changed || changed;
        }
        continue;
      }

      if (n->kind() == prim::FunctionalGraph &&
          isEmptyFunctionalGraph(functional_subgraph_node)) {
        functional_subgraph_node->destroy();
        functional_subgraph_node = n;
        continue;
      }

      changed = true;
      if (aliasDb_->moveBeforeTopologicallyValid(n, functional_subgraph_node)) {
        SubgraphUtils::mergeNodeIntoSubgraph(n, functional_subgraph_node);
      } else {
        functional_graph_nodes.emplace_back(functional_subgraph_node);
        functional_subgraph_node =
            graph_->createWithSubgraph(prim::FunctionalGraph)->insertAfter(n);
        SubgraphUtils::mergeNodeIntoSubgraph(n, functional_subgraph_node);
      }
    }
    functional_graph_nodes.emplace_back(functional_subgraph_node);

    for (Node* functional_node : functional_graph_nodes) {
      if (!inlineIfTooSmall(functional_node)) {
        ConstantPooling(functional_node->g(attr::Subgraph));
      }
    }
    return changed;
  }

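  // Returns true if `n` is functional: none of its outputs have writers or
  // escape scope, all of its blocks are functional, and the node itself is
  // not mutating. Values with writers are recorded in mutated_values_.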
  bool AnalyzeFunctionalSubset(Node* n) {
    // TODO: clarify hasSideEffects, isNondeterministic
    bool is_functional_node = true;

    // Functional graphs are not responsible for maintaining aliasing
    // relationships. If an output of a functional graph escapes scope
    // or is mutated, then we might change the semantics of the program if
    // aliasing relationships are changed.
    // We don't allow any node in the functional graph to output a value
    // that escapes scope or is mutated, and we don't allow any mutating nodes
    // into the graph.
    // Possible future relaxations:
    // - allow functional graphs to have at most one value that can escape scope
    // - allow outputs which alias the wildcard set but do not "re-escape"
    for (Value* v : n->outputs()) {
      bool has_writers = aliasDb_->hasWriters(v);
      bool escapes_scope = aliasDb_->escapesScope(v);
      if (has_writers) {
        mutated_values_.insert(v);
      }
      is_functional_node = is_functional_node && !escapes_scope && !has_writers;
    }

    for (Block* block : n->blocks()) {
      auto functional_block = AnalyzeFunctionalSubset(block);
      is_functional_node = is_functional_node && functional_block;
    }

    is_functional_node = is_functional_node && !aliasDb_->isMutable(n);
    if (is_functional_node) {
      functional_nodes_.insert(n);
    }
    return is_functional_node;
  }

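  // Convenience overload: analyzes each block in `blocks`.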
  void AnalyzeFunctionalSubset(at::ArrayRef<Block*> blocks) {
    for (Block* block : blocks) {
      AnalyzeFunctionalSubset(block);
    }
  }

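  // Returns true if every node in `block` is functional. Block inputs that
  // have writers are recorded as mutated values.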
  bool AnalyzeFunctionalSubset(Block* block) {
    bool is_functional_block = true;
    // Block inputs have not yet been iterated over, so record any of them
    // that have writers in our set of mutated values.
    for (Value* v : block->inputs()) {
      bool has_writers = aliasDb_->hasWriters(v);
      if (has_writers) {
        mutated_values_.insert(v);
      }
    }
    // if a block output is not functional, then the corresponding output for
    // the node that contains the block will not be functional either, so we do
    // not need to analyze the block outputs here.
    for (Node* n : block->nodes()) {
      bool functional = AnalyzeFunctionalSubset(n);
      is_functional_block = is_functional_block && functional;
    }
    return is_functional_block;
  }

  // Nodes determined to be functional by AnalyzeFunctionalSubset.
  std::unordered_set<Node*> functional_nodes_;
  // Values that have at least one writer according to the AliasDb.
  std::unordered_set<Value*> mutated_values_;
  std::shared_ptr<Graph> graph_;
  // Rebuilt on every iteration of run(), since graph mutation invalidates it.
  std::unique_ptr<AliasDb> aliasDb_ = nullptr;
  // Functional subgraphs with fewer non-constant nodes than this are inlined
  // back into their owning block.
  size_t minSubgraphSize_ = 6;
};

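// Recursively unmerges every prim::FunctionalGraph node in `block` and its
// sub-blocks back into the surrounding graph.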
void InlineFunctionalGraphs(Block* block) {
  for (auto it = block->nodes().begin(); it != block->nodes().end();) {
    Node* n = *it;
    it++;
    for (Block* b : n->blocks()) {
      InlineFunctionalGraphs(b);
    }
    if (n->kind() == prim::FunctionalGraph) {
      SubgraphUtils::unmergeSubgraph(n);
    }
  }
}

} // namespace

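// Partitions `graph` into prim::FunctionalGraph subgraphs of nodes that the
// AliasDb shows to be free of mutation and scope escape; InlineFunctionalGraphs
// below is the inverse operation, folding those subgraphs back into the graph.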
void CreateFunctionalGraphs(const std::shared_ptr<Graph>& graph) {
  // Run Constant Pooling so constants get hoisted
  ConstantPooling(graph);
  FunctionalGraphSlicer func(graph);
  func.run();
  // Creating functional subgraphs (and un-merging the ones that are too
  // small) copies constants, so run Constant Pooling again to deduplicate.
  ConstantPooling(graph);
}

void InlineFunctionalGraphs(const std::shared_ptr<Graph>& graph) {
  InlineFunctionalGraphs(graph->block());
}

} // namespace torch::jit