xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/create_functional_graphs.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/passes/create_functional_graphs.h>
2 
3 #include <c10/util/Exception.h>
4 #include <torch/csrc/jit/ir/alias_analysis.h>
5 #include <torch/csrc/jit/passes/constant_pooling.h>
6 #include <torch/csrc/jit/passes/utils/subgraph_utils.h>
7 
8 #include <cstddef>
9 #include <limits>
10 
11 namespace torch::jit {
12 
13 namespace {
14 
15 struct FunctionalGraphSlicer {
FunctionalGraphSlicertorch::jit::__anon9fa26a170111::FunctionalGraphSlicer16   FunctionalGraphSlicer(std::shared_ptr<Graph> graph)
17       : graph_(std::move(graph)) {}
18 
runtorch::jit::__anon9fa26a170111::FunctionalGraphSlicer19   void run() {
20     bool changed = true;
21     // TODO: more sane strategy
22     size_t MAX_NUM_ITERATIONS = 4;
23 
24     // First, analyze the functional subset of the graph, and then create
25     // functional graphs. The graph gets mutated when we create functional
26     // subgraphs, invalidating the AliasDb, so we need to do our analysis
27     // first.
28     for (size_t i = 0; i < MAX_NUM_ITERATIONS && changed; ++i) {
29       aliasDb_ = std::make_unique<AliasDb>(graph_);
30       AnalyzeFunctionalSubset(graph_->block());
31       changed = CreateFunctionalGraphsImpl(graph_->block());
32     }
33   }
34 
35  private:
isEmptyFunctionalGraphtorch::jit::__anon9fa26a170111::FunctionalGraphSlicer36   bool isEmptyFunctionalGraph(Node* n) {
37     auto g = n->g(attr::Subgraph);
38     return g->inputs().empty() && g->outputs().empty();
39   }
40 
nonConstNodestorch::jit::__anon9fa26a170111::FunctionalGraphSlicer41   void nonConstNodes(Block* block, size_t* num) {
42     for (auto it = block->nodes().begin();
43          it != block->nodes().end() && *num < minSubgraphSize_;
44          ++it) {
45       Node* n = *it;
46       if (n->kind() == prim::Constant) {
47         continue;
48       }
49       *num = *num + 1;
50       for (Block* b : n->blocks()) {
51         nonConstNodes(b, num);
52       }
53     }
54   }
55 
inlineIfTooSmalltorch::jit::__anon9fa26a170111::FunctionalGraphSlicer56   bool inlineIfTooSmall(Node* n) {
57     AT_ASSERT(n->kind() == prim::FunctionalGraph);
58     auto subgraph = SubgraphUtils::getSubgraph(n);
59     size_t num_modes = 0;
60     nonConstNodes(subgraph->block(), &num_modes);
61     if (num_modes < minSubgraphSize_) {
62       SubgraphUtils::unmergeSubgraph(n);
63       return true;
64     }
65     return false;
66   }
67 
CreateFunctionalGraphsImpltorch::jit::__anon9fa26a170111::FunctionalGraphSlicer68   bool CreateFunctionalGraphsImpl(Block* block) {
69     /*
70     Iterate the block in reverse and create FunctionalSubgraphs.
71     When we encounter a node that isn't functional, we skip it. Otherwise,
72     we try to merge the functional node into the current functional subgraph.
73     If it can't be merged into the current functional subgraph node, then we
74     start a functional subgraph group.
75     */
76     bool changed = false;
77     std::vector<Node*> functional_graph_nodes;
78 
79     Node* functional_subgraph_node =
80         graph_->createWithSubgraph(prim::FunctionalGraph)
81             ->insertBefore(block->return_node());
82     auto reverse_iter = block->nodes().reverse();
83     std::vector<Value*> graph_outputs;
84     for (auto it = reverse_iter.begin(); it != reverse_iter.end();) {
85       Node* n = *it++;
86 
87       // constants get copied into the graph
88       if (n->kind() == prim::Constant || n == functional_subgraph_node) {
89         continue;
90       }
91 
92       // if `n` is functional, all of its blocks will be merged into the
93       // new functional subgraph, so we only need to recurse if it is not
94       // functional
95       if (!functional_nodes_.count(n)) {
96         for (Block* b : n->blocks()) {
97           auto block_changed = CreateFunctionalGraphsImpl(b);
98           changed = block_changed && changed;
99         }
100         continue;
101       }
102 
103       if (n->kind() == prim::FunctionalGraph &&
104           isEmptyFunctionalGraph(functional_subgraph_node)) {
105         functional_subgraph_node->destroy();
106         functional_subgraph_node = n;
107         continue;
108       }
109 
110       changed = true;
111       if (aliasDb_->moveBeforeTopologicallyValid(n, functional_subgraph_node)) {
112         SubgraphUtils::mergeNodeIntoSubgraph(n, functional_subgraph_node);
113       } else {
114         functional_graph_nodes.emplace_back(functional_subgraph_node);
115         functional_subgraph_node =
116             graph_->createWithSubgraph(prim::FunctionalGraph)->insertAfter(n);
117         SubgraphUtils::mergeNodeIntoSubgraph(n, functional_subgraph_node);
118       }
119     }
120     functional_graph_nodes.emplace_back(functional_subgraph_node);
121 
122     for (Node* functional_node : functional_graph_nodes) {
123       if (!inlineIfTooSmall(functional_node)) {
124         ConstantPooling(functional_node->g(attr::Subgraph));
125       }
126     }
127     return changed;
128   }
129 
AnalyzeFunctionalSubsettorch::jit::__anon9fa26a170111::FunctionalGraphSlicer130   bool AnalyzeFunctionalSubset(Node* n) {
131     // TODO: clarify hasSideEffects, isNondeterministic
132     bool is_functional_node = true;
133 
134     // Functional Graphs are not responsible for maintaining aliasing
135     // relationships. If an output of a functional graph escapes scope
136     // or is mutated then we might change semantics of the program if
137     // aliasing relationships are changed.
138     // We don't allow any node in the functional graph to output a value
139     // that escapes scope or is mutated, and we don't allow any mutating nodes
140     // into the graph.
141     // - allow functional graphs to have at most one value that can escape scope
142     // - allow outputs which alias the wildcard set but do not "re-escape"
143     for (Value* v : n->outputs()) {
144       bool has_writers = aliasDb_->hasWriters(v);
145       bool escapes_scope = aliasDb_->escapesScope(v);
146       if (has_writers) {
147         mutated_values_.insert(v);
148       }
149       is_functional_node = is_functional_node && !escapes_scope && !has_writers;
150     }
151 
152     for (Block* block : n->blocks()) {
153       auto functional_block = AnalyzeFunctionalSubset(block);
154       is_functional_node = is_functional_node && functional_block;
155     }
156 
157     is_functional_node = is_functional_node && !aliasDb_->isMutable(n);
158     if (is_functional_node) {
159       functional_nodes_.insert(n);
160     }
161     return is_functional_node;
162   }
163 
AnalyzeFunctionalSubsettorch::jit::__anon9fa26a170111::FunctionalGraphSlicer164   void AnalyzeFunctionalSubset(at::ArrayRef<Block*> blocks) {
165     for (Block* block : blocks) {
166       AnalyzeFunctionalSubset(block);
167     }
168   }
169 
AnalyzeFunctionalSubsettorch::jit::__anon9fa26a170111::FunctionalGraphSlicer170   bool AnalyzeFunctionalSubset(Block* block) {
171     bool is_functional_block = true;
172     // block inputs will not yet have been iterated through,
173     // so we need to add them to our set of mutated & escape values.
174     for (Value* v : block->inputs()) {
175       bool has_writers = aliasDb_->hasWriters(v);
176       if (has_writers) {
177         mutated_values_.insert(v);
178       }
179     }
180     // if a block output is not functional, then the corresponding output for
181     // the node that contains the block will not be functional either, so we do
182     // not need to analyze the block outputs here.
183     for (Node* n : block->nodes()) {
184       bool functional = AnalyzeFunctionalSubset(n);
185       is_functional_block = is_functional_block && functional;
186     }
187     return is_functional_block;
188   }
189 
190   std::unordered_set<Node*> functional_nodes_;
191   std::unordered_set<Value*> mutated_values_;
192   std::shared_ptr<Graph> graph_;
193   std::unique_ptr<AliasDb> aliasDb_ = nullptr;
194   size_t minSubgraphSize_ = 6;
195 };
196 
InlineFunctionalGraphs(Block * block)197 void InlineFunctionalGraphs(Block* block) {
198   for (auto it = block->nodes().begin(); it != block->nodes().end();) {
199     Node* n = *it;
200     it++;
201     for (Block* b : n->blocks()) {
202       InlineFunctionalGraphs(b);
203     }
204     if (n->kind() == prim::FunctionalGraph) {
205       SubgraphUtils::unmergeSubgraph(n);
206     }
207   }
208 }
209 
210 } // namespace
211 
CreateFunctionalGraphs(const std::shared_ptr<Graph> & graph)212 void CreateFunctionalGraphs(const std::shared_ptr<Graph>& graph) {
213   // Run Constant Pooling so constants get hoisted
214   ConstantPooling(graph);
215   FunctionalGraphSlicer func(graph);
216   func.run();
217   // Creation of Functional Subgraphs & Deinlining creates excess constants
218   ConstantPooling(graph);
219 }
220 
InlineFunctionalGraphs(const std::shared_ptr<Graph> & graph)221 void InlineFunctionalGraphs(const std::shared_ptr<Graph>& graph) {
222   InlineFunctionalGraphs(graph->block());
223 }
224 
225 } // namespace torch::jit
226