xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/ir/alias_analysis.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/alias_info.h>
4 #include <c10/util/flat_hash_map.h>
5 #include <torch/csrc/jit/ir/ir.h>
6 #include <torch/csrc/jit/ir/type_hashing.h>
7 #include <torch/csrc/jit/passes/create_functional_graphs.h>
8 #include <torch/csrc/jit/passes/utils/memory_dag.h>
9 
10 namespace torch::jit {
11 
12 /**
13  * Alias analysis pass.
14  *
15  * This pass produces an AliasDb that contains aliasing and mutation
16  * information about the graph. Users can use this information to determine
17  * whether mutations to the graph are safe, i.e. they don't reorder/change
18  * nodes in a way that affects output.
19  *
20  * Every value with a mutable type (Tensors, Lists, Tuples, etc.) will be
21  * associated with one or more "alias sets". If two values share an alias set,
22  * that means they may alias, implying that a mutation to one value cannot be
23  * reordered past a use of the other. Only reordering two reads of an alias set
24  * is considered safe.
25  *
26  * There is a special alias set called the "wildcard set", which indicates that
27  * we're not sure what this value may alias. To be conservative, we consider the
28  * wildcard alias set as potentially aliasing any other wildcard value within
29  * the same type class. Whenever a value becomes contained by another value,
30  * such as when a Tensor is appended to a List[Tensor], the contained element
31  * becomes part of the wildcard set.
32  *
33  * Values that contain other mutable types, such as List[Tensor], are
34  * initialized as containing the Wildcard set for all contained mutable types.
35  *
36  * The AliasDb API references the idea of "mutable" vs "immutable"
37  * types. "Mutable" means that the object's value can change, while
38  * "immutable" means that the value is fixed. (For example, `List` is
39  * mutable, so you can add and delete elements from it. On the other
40  * hand, you can't modify a Tuple once you create it, making `Tuple` an
41  * immutable container.)
42  *
43  * `isFrozen` - if the Module is frozen then consider attributes as freshly
44  * created objects. Freezing API invokes alias analysis to check if they are
45  * mutated internally.
46  *
47  * `descendFunctionCalls` - recursively analyze function and method calls
48  * instead of conservative analysis. Generally analysis should be done after
49  * inlining so the implementation for recursive analysis is unoptimized.
50  */
51 class AliasDb {
52  public:
53   TORCH_API explicit AliasDb(
54       std::shared_ptr<Graph> graphi,
55       bool isFrozen = false,
56       bool descendFunctionCalls = false);
57   TORCH_API ~AliasDb();
58 
59   // There are limitations to what effects the alias analysis can track. Two
60   // kinds of nodes may have untracked effects:
61   // 1. Nodes that write to a value that may alias the graph inputs (since
62   //    the inputs can be used outside the graph).
63   // 2. Nodes that write to something in the wildcard set.
64   //
65   // These nodes are considered not safe to eliminate or mutate under any
66   // circumstances.
67   bool writesToWildcard(Node* n) const;
68 
69   // Does `n` write to an alias of one of the values in `vs`?
70   // if `recurseBlocks` is true, consider writes on the nodes in `n`s sub-blocks
71   TORCH_API bool writesToAlias(Node* n, const ValueSet& vs) const;
72 
73   // Does `a` and `b` potentially share a memory location or do either
74   // hold in memory any element that exists in the other
75   TORCH_API bool mayContainAlias(Value* a, Value* b) const;
76 
77   TORCH_API bool mayContainAlias(Value* a, const at::ArrayRef<Value*> b) const;
78 
79   // Do any values in group `a` share a memory location or hold in memory
80   // any element that exists in group `b`
81   TORCH_API bool mayContainAlias(
82       const at::ArrayRef<Value*> a,
83       const at::ArrayRef<Value*> b) const;
84 
85   // Do `a` and `b` potentially share a memory location?
86   TORCH_API bool mayAlias(const Value* a, const Value* b) const;
87   // Do any values in group `a` potentially share a memory location with any
88   // value in group `b`? i.e. may they overlap?
89   TORCH_API bool mayAlias(const ValueSet& a, const ValueSet& b) const;
90 
91   // Do any nodes write to an alias set input to `n`?
92   TORCH_API bool hasInputWriters(const Node* n) const;
93 
94   // Do any nodes write to an alias set output by `n`?
95   TORCH_API bool hasOutputWriters(const Node* n) const;
96 
97   // Do any nodes write to an alias set inputed/outputed by `n`?
98   TORCH_API bool hasWriters(const Node* n) const;
99 
100   // Do any nodes write to `v`s memory location?
101   TORCH_API bool hasWriters(const Value* v) const;
102 
103   // Is the operation in-place? i.e. doesn't write anywhere but locations it
104   // reads from.
105   TORCH_API bool isMutable(Node* n) const;
106 
107   TORCH_API bool escapesScope(const at::ArrayRef<Value*>& vs) const;
108 
109   // Is it safe to change whether `a` and `b` alias each other ?
110   TORCH_API bool safeToChangeAliasingRelationship(
111       const at::ArrayRef<Value*>& a,
112       const at::ArrayRef<Value*>& b) const;
113 
114   // Move `n` (already in the graph) after `movePoint` in the topological order.
115   //
116   // Tries to preserve value dependencies, so other nodes might be moved. We
117   // make two guarantees about the postcondition of the node list:
118   //   - `n` is directly after `movePoint`.
119   //   - only nodes between `n` and `movePoint` have been moved.
120   //
121   // Returns `false` if it's impossible to move `n` after `MovePoint` without
122   // violating dependencies, otherwise executes the move and returns `true`
123   TORCH_API bool moveAfterTopologicallyValid(Node* n, Node* movePoint);
124   TORCH_API bool moveBeforeTopologicallyValid(Node* n, Node* movePoint);
125 
126   bool couldMoveAfterTopologically(Node* n, Node* movePoint);
127   bool couldMoveBeforeTopologically(Node* n, Node* movePoint);
128 
129   // For debugging: print alias db state to stdout
130   TORCH_API void dump() const;
131   TORCH_API std::string toString() const;
132 
133   // Generates a DOT (www.graphviz.org) graph representation
134   //
135   // Returns `true` if the output file was successfully generated
136   //
137   // WARNING: The output dot file path can't include shell specific notations,
138   //  for example you can't use "~/temp/aliasdb.dot"
139   //  (instead, use "/home/user/temp/aliasdb.dot")
140   //
141   TORCH_API bool dumpToGraphvizFile(const char* filename) const;
142   TORCH_API std::string toGraphviz() const;
143 
144   // Returns `true` if the given element is mutable or if it is a
145   // container type with an internal mutable element (e.g.
146   // `Tuple[int, Tensor]` has an internal mutable type `Tensor`, so
147   // it would be considered a "mutable type" in AliasDb)
148   static bool isMutableType(const Value* v);
149   static bool isMutableType(const TypePtr& type);
150 
151   /**
152    * Mutation API
153    *
154    * These methods allow you to update AliasDb in-place if you are performing
155    * graph mutation.
156    *
157    * WARNING: These methods should be considered INTERNAL. They do not perform
158    * very many correctness checks, the user is responsible for making sure they
159    * are updating AliasDb correctly. `Lint()`ing the AliasDb can help with
160    * this.
161    */
162   // Copy `existing`s aliasing info to `new_value`, and remove `existing`.
163   TORCH_API void replaceWithNewValue(Value* existing, Value* new_value);
164   // Copy `from`s aliasing info to `to`.
165   TORCH_API void copyValue(Value* from, Value* to);
166   // Create a new `value` that does not alias anything else.
167   TORCH_API void createValue(const Value* value);
168 
169   // Enable more precise treatment of prim::TupleConstruct.
170   void enablePreciseTupleContainerAnalysis();
171 
172   friend struct MutationRemover;
173 
174  private:
175   // Helper for topologically-safe node moves.
176   class WorkingSet;
177   enum class MoveSide { BEFORE, AFTER };
178   bool tryMove(Node* toMove, Node* movePoint, MoveSide moveSide, bool dryRun);
179   void move(Node* toMove, Node* movePoint, MoveSide moveSide);
180   bool isBeforeOrAfter(const Node* n, MoveSide moveSide) const;
181 
182   bool isMutableTypeInternal(const Value* v) const;
183   bool isMutableTypeInternal(const TypePtr& type) const;
184 
185   /**
186    * Write and read internal API
187    */
188   // Get all the values that `n` writes to.
189   // NOTE: this only returns values directly written to, not aliases thereof
190   //
191   // if `recurseBlocks` is true, gather writes on the nodes in `n`s sub-blocks
192   MemoryLocations getWrites(Node* n) const;
193   void getWritesImpl(Node* n, MemoryLocations& ret) const;
194   // Register the fact that `n` writes to `v`.
195   void registerWrite(const Value* v, Node* n, bool writeToContained = false);
196   // Get all the values that `n` reads from.
197   // if `recurseBlocks` is true, gather reads on the nodes in `n`s sub-blocks
198   MemoryLocations getReads(Node* n) const;
199   void getReadsImpl(Node* n, MemoryLocations& ret) const;
200 
201   /**
202    * Wildcard methods
203    */
204   // Register `v` as a wildcard value.
205   std::optional<Element*> setWildcard(const Value* v);
206 
207   // Is this a value which will not alias?
208   bool nonAliasingValue(const Value* elem) const;
209 
210   /**
211    * Special analysis methods
212    */
213   void analyze(const std::shared_ptr<Graph>& graph);
214   void analyze(Block* block);
215   void analyze(Node* node);
216   void analyzeImpl(Node* node);
217   void analyzeIf(Node* node);
218   void analyzeLoop(Node* node);
219   void analyzeSubgraph(Node* node, const std::shared_ptr<Graph>& subgraph);
220   void analyzeSubgraph(Node* node);
221   void analyzeCreator(Node* node);
222   void analyzeExtractor(Node* node);
223   void analyzeChunk(Node* node);
224   void analyzeBroadcastingChunk(Node* node);
225   void analyzeFork(Node* node);
226   void analyzeWait(Node* node);
227   void analyzeAwaitable(Node* node);
228   void analyzeAwaitableWait(Node* node);
229   void analyzeRpcAsync(Node* node);
230   void analyzeBatchNorm(Node* node);
231   void analyzeInstanceNorm(Node* node);
232   void analyzeGradOf(Node* node);
233   void analyzeSetAttr(Node* node);
234   void analyzeConservative(Node* node);
235   void analyzeContainerConstruct(Node* node);
236   bool tryRegisteredAnalysis(Node* node);
237 
238   /**
239    * Alias manipulation methods
240    */
241   void makeAllAlias(const std::vector<Value*>& values);
242   void makePointerTo(const Value* value, const Value* to);
243   TORCH_API void addToContainedElements(
244       const Value* element,
245       const Value* container);
246   void mapAliases(at::ArrayRef<Value*> to, at::ArrayRef<Value*> from);
247   void giveFreshAlias(
248       const Value* value,
249       bool add_wildcard_to_contained_elems = true);
250   Element* getOrCreateElement(const Value* value);
251 
252   const AliasTypeSet* mapTypeToAliasTypeSetPtr(const TypePtr& type) const;
253   bool functionalNonEscapingListUse(const Use& use) const;
254   bool functionalNonEscapingTupleUse(const Use& use) const;
255 
256   std::shared_ptr<Graph> graph_;
257 
258   // If the Module is frozen then consider attributes as freshly created
259   // objects. Freezing API invokes alias analysis to check if they are mutated
260   // internally.
261   bool isFrozen_;
262 
263   bool descend_function_calls_;
264   std::unordered_map<Graph*, std::vector<std::shared_ptr<Graph>>>
265       function_call_copies_;
266 
267   // The points-to graph that stores aliasing relationships
268   std::unique_ptr<MemoryDAGBuilder> memoryDAGBuilder_;
269   std::unique_ptr<MemoryDAG> memoryDAG_;
270 
271   // Mapping of values to MemoryDAG elements
272   ska::flat_hash_map<const Value*, Element*> elementMap_;
273   // All wildcard Elements (one for each unique mutable type)
274   ska::flat_hash_map<TypePtr, Element*, HashType, EqualType> wildcardIndex_;
275   Element* getWildcard(const TypePtr& type) const;
276   std::optional<Element*> tryGetOrCreateWildcard(const TypePtr& type);
277   void addContainedTypesToFreshElement(
278       Element* container_elem,
279       const AliasTypeSet& mut_types);
280   void pointUnionTypeElementToAllContainedTypes(
281       Element* container_elem,
282       const AliasTypeSet& mut_types);
283 
284   std::vector<Element*> getElements(at::ArrayRef<Value*> vs) const;
285   bool mayAliasWildcard(const Value* v) const;
286   bool mayAliasWildcard(const at::ArrayRef<Value*> vs) const;
287   bool hasWriters(const at::ArrayRef<Value*>& values) const;
288 
289   // Cached mapping of type ptrs to their mutable types
290   mutable ska::flat_hash_map<TypePtr, AliasTypeSet> mapped_mutable_types_;
291 
292   /**
293    * State for tracking write info.
294    */
295   // Write registry where the analysis can record the writes as it sees them.
296   // This information is later denormalized into various caches to improve query
297   // efficiency.
298   struct WriteRegistry;
299   std::unique_ptr<WriteRegistry> writeRegistry_;
300 
301   // Map of nodes to the memory locations that they write to
302   using TWriteIndex = ska::flat_hash_map<Node*, MemoryLocations>;
303   std::optional<TWriteIndex> writeIndex_;
304   // Collection of all memory locations that are written to.
305   std::optional<MemoryLocations> writtenToLocationsIndex_;
306   void buildWrittenToLocationsIndex();
307 
308   std::unordered_set<const Value*> wildcards_;
309 
310   std::string getElementName(const Element* e) const;
311 
312   friend void Lint(const AliasDb* db);
313 };
314 
315 // Helper check that invariants over AliasDb are maintained.
316 // Useful if you are using the AliasDb mutation API (replaceWithNewValue,
317 // copyValue, createValue) and want to check you did the right thing.
318 TORCH_API void Lint(const AliasDb* db);
319 
320 } // namespace torch::jit
321