#pragma once

#include <ATen/core/alias_info.h>
#include <c10/util/flat_hash_map.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/type_hashing.h>
#include <torch/csrc/jit/passes/create_functional_graphs.h>
#include <torch/csrc/jit/passes/utils/memory_dag.h>

namespace torch::jit {

/**
 * Alias analysis pass.
 *
 * This pass produces an AliasDb that contains aliasing and mutation
 * information about the graph. Users can use this information to determine
 * whether mutations to the graph are safe, i.e. they don't reorder/change
 * nodes in a way that affects output.
 *
 * Every value with a mutable type (Tensors, Lists, Tuples, etc.) will be
 * associated with one or more "alias sets". If two values share an alias set,
 * that means they may alias, implying that a mutation to one value cannot be
 * reordered past a use of the other. Only reordering two reads of an alias set
 * is considered safe.
 *
 * There is a special alias set called the "wildcard set", which indicates that
 * we're not sure what this value may alias. To be conservative, we consider
 * the wildcard alias set as potentially aliasing any other wildcard value
 * within the same type class. Whenever a value becomes contained by another
 * value, such as when a Tensor is appended to a List[Tensor], the contained
 * element becomes part of the wildcard set.
 *
 * Values that contain other mutable types, such as List[Tensor], are
 * initialized as containing the Wildcard set for all contained mutable types.
 *
 * The AliasDb API references the idea of "mutable" vs "immutable"
 * types. "Mutable" means that the object's value can change, while
 * "immutable" means that the value is fixed. (For example, `List` is
 * mutable, so you can add and delete elements from it. On the other
 * hand, you can't modify a Tuple once you create it, making `Tuple` an
 * immutable container.)
 *
 * `isFrozen` - if the Module is frozen then consider attributes as freshly
 * created objects. Freezing API invokes alias analysis to check if they are
 * mutated internally.
 *
 * `descendFunctionCalls` - recursively analyze function and method calls
 * instead of conservative analysis. Generally analysis should be done after
 * inlining so the implementation for recursive analysis is unoptimized.
 */
class AliasDb {
 public:
  TORCH_API explicit AliasDb(
      std::shared_ptr<Graph> graphi,
      bool isFrozen = false,
      bool descendFunctionCalls = false);
  TORCH_API ~AliasDb();

  // There are limitations to what effects the alias analysis can track. Two
  // kinds of nodes may have untracked effects:
  // 1. Nodes that write to a value that may alias the graph inputs (since
  //    the inputs can be used outside the graph).
  // 2. Nodes that write to something in the wildcard set.
  //
  // These nodes are considered not safe to eliminate or mutate under any
  // circumstances.
  bool writesToWildcard(Node* n) const;

  // Does `n` write to an alias of one of the values in `vs`?
  // if `recurseBlocks` is true, consider writes on the nodes in `n`s
  // sub-blocks
  TORCH_API bool writesToAlias(Node* n, const ValueSet& vs) const;

  // Do `a` and `b` potentially share a memory location, or does either
  // hold in memory any element that exists in the other?
  TORCH_API bool mayContainAlias(Value* a, Value* b) const;

  TORCH_API bool mayContainAlias(Value* a, const at::ArrayRef<Value*> b) const;

  // Do any values in group `a` share a memory location or hold in memory
  // any element that exists in group `b`
  TORCH_API bool mayContainAlias(
      const at::ArrayRef<Value*> a,
      const at::ArrayRef<Value*> b) const;

  // Do `a` and `b` potentially share a memory location?
  TORCH_API bool mayAlias(const Value* a, const Value* b) const;
  // Do any values in group `a` potentially share a memory location with any
  // value in group `b`? i.e. may they overlap?
  TORCH_API bool mayAlias(const ValueSet& a, const ValueSet& b) const;

  // Do any nodes write to an alias set input to `n`?
  TORCH_API bool hasInputWriters(const Node* n) const;

  // Do any nodes write to an alias set output by `n`?
  TORCH_API bool hasOutputWriters(const Node* n) const;

  // Do any nodes write to an alias set input to or output by `n`?
  TORCH_API bool hasWriters(const Node* n) const;

  // Do any nodes write to `v`s memory location?
  TORCH_API bool hasWriters(const Value* v) const;

  // Is the operation in-place? i.e. doesn't write anywhere but locations it
  // reads from.
  TORCH_API bool isMutable(Node* n) const;

  TORCH_API bool escapesScope(const at::ArrayRef<Value*>& vs) const;

  // Is it safe to change whether `a` and `b` alias each other?
  TORCH_API bool safeToChangeAliasingRelationship(
      const at::ArrayRef<Value*>& a,
      const at::ArrayRef<Value*>& b) const;

  // Move `n` (already in the graph) after `movePoint` in the topological
  // order.
  //
  // Tries to preserve value dependencies, so other nodes might be moved. We
  // make two guarantees about the postcondition of the node list:
  //   - `n` is directly after `movePoint`.
  //   - only nodes between `n` and `movePoint` have been moved.
  //
  // Returns `false` if it's impossible to move `n` after `movePoint` without
  // violating dependencies, otherwise executes the move and returns `true`
  TORCH_API bool moveAfterTopologicallyValid(Node* n, Node* movePoint);
  TORCH_API bool moveBeforeTopologicallyValid(Node* n, Node* movePoint);

  bool couldMoveAfterTopologically(Node* n, Node* movePoint);
  bool couldMoveBeforeTopologically(Node* n, Node* movePoint);

  // For debugging: print alias db state to stdout
  TORCH_API void dump() const;
  TORCH_API std::string toString() const;

  // Generates a DOT (www.graphviz.org) graph representation
  //
  // Returns `true` if the output file was successfully generated
  //
  // WARNING: The output dot file path can't include shell specific notations,
  // for example you can't use "~/temp/aliasdb.dot"
  // (instead, use "/home/user/temp/aliasdb.dot")
  //
  TORCH_API bool dumpToGraphvizFile(const char* filename) const;
  TORCH_API std::string toGraphviz() const;

  // Returns `true` if the given element is mutable or if it is a
  // container type with an internal mutable element (e.g.
  // `Tuple[int, Tensor]` has an internal mutable type `Tensor`, so
  // it would be considered a "mutable type" in AliasDb)
  static bool isMutableType(const Value* v);
  static bool isMutableType(const TypePtr& type);

  /**
   * Mutation API
   *
   * These methods allow you to update AliasDb in-place if you are performing
   * graph mutation.
   *
   * WARNING: These methods should be considered INTERNAL. They do not perform
   * very many correctness checks, the user is responsible for making sure
   * they are updating AliasDb correctly. `Lint()`ing the AliasDb can help
   * with this.
   */
  // Copy `existing`s aliasing info to `new_value`, and remove `existing`.
  TORCH_API void replaceWithNewValue(Value* existing, Value* new_value);
  // Copy `from`s aliasing info to `to`.
  TORCH_API void copyValue(Value* from, Value* to);
  // Create a new `value` that does not alias anything else.
  TORCH_API void createValue(const Value* value);

  // Enable more precise treatment of prim::TupleConstruct.
  void enablePreciseTupleContainerAnalysis();

  friend struct MutationRemover;

 private:
  // Helper for topologically-safe node moves.
  class WorkingSet;
  enum class MoveSide { BEFORE, AFTER };
  bool tryMove(Node* toMove, Node* movePoint, MoveSide moveSide, bool dryRun);
  void move(Node* toMove, Node* movePoint, MoveSide moveSide);
  bool isBeforeOrAfter(const Node* n, MoveSide moveSide) const;

  bool isMutableTypeInternal(const Value* v) const;
  bool isMutableTypeInternal(const TypePtr& type) const;

  /**
   * Write and read internal API
   */
  // Get all the values that `n` writes to.
  // NOTE: this only returns values directly written to, not aliases thereof
  //
  // if `recurseBlocks` is true, gather writes on the nodes in `n`s sub-blocks
  MemoryLocations getWrites(Node* n) const;
  void getWritesImpl(Node* n, MemoryLocations& ret) const;
  // Register the fact that `n` writes to `v`.
  void registerWrite(const Value* v, Node* n, bool writeToContained = false);
  // Get all the values that `n` reads from.
  // if `recurseBlocks` is true, gather reads on the nodes in `n`s sub-blocks
  MemoryLocations getReads(Node* n) const;
  void getReadsImpl(Node* n, MemoryLocations& ret) const;

  /**
   * Wildcard methods
   */
  // Register `v` as a wildcard value.
  std::optional<Element*> setWildcard(const Value* v);

  // Is this a value which will not alias?
  bool nonAliasingValue(const Value* elem) const;

  /**
   * Special analysis methods
   */
  void analyze(const std::shared_ptr<Graph>& graph);
  void analyze(Block* block);
  void analyze(Node* node);
  void analyzeImpl(Node* node);
  void analyzeIf(Node* node);
  void analyzeLoop(Node* node);
  void analyzeSubgraph(Node* node, const std::shared_ptr<Graph>& subgraph);
  void analyzeSubgraph(Node* node);
  void analyzeCreator(Node* node);
  void analyzeExtractor(Node* node);
  void analyzeChunk(Node* node);
  void analyzeBroadcastingChunk(Node* node);
  void analyzeFork(Node* node);
  void analyzeWait(Node* node);
  void analyzeAwaitable(Node* node);
  void analyzeAwaitableWait(Node* node);
  void analyzeRpcAsync(Node* node);
  void analyzeBatchNorm(Node* node);
  void analyzeInstanceNorm(Node* node);
  void analyzeGradOf(Node* node);
  void analyzeSetAttr(Node* node);
  void analyzeConservative(Node* node);
  void analyzeContainerConstruct(Node* node);
  bool tryRegisteredAnalysis(Node* node);

  /**
   * Alias manipulation methods
   */
  void makeAllAlias(const std::vector<Value*>& values);
  void makePointerTo(const Value* value, const Value* to);
  TORCH_API void addToContainedElements(
      const Value* element,
      const Value* container);
  void mapAliases(at::ArrayRef<Value*> to, at::ArrayRef<Value*> from);
  void giveFreshAlias(
      const Value* value,
      bool add_wildcard_to_contained_elems = true);
  Element* getOrCreateElement(const Value* value);

  const AliasTypeSet* mapTypeToAliasTypeSetPtr(const TypePtr& type) const;
  bool functionalNonEscapingListUse(const Use& use) const;
  bool functionalNonEscapingTupleUse(const Use& use) const;

  std::shared_ptr<Graph> graph_;

  // If the Module is frozen then consider attributes as freshly created
  // objects. Freezing API invokes alias analysis to check if they are mutated
  // internally.
  bool isFrozen_;

  bool descend_function_calls_;
  std::unordered_map<Graph*, std::vector<std::shared_ptr<Graph>>>
      function_call_copies_;

  // The points-to graph that stores aliasing relationships
  std::unique_ptr<MemoryDAGBuilder> memoryDAGBuilder_;
  std::unique_ptr<MemoryDAG> memoryDAG_;

  // Mapping of values to MemoryDAG elements
  ska::flat_hash_map<const Value*, Element*> elementMap_;
  // All wildcard Elements (one for each unique mutable type)
  ska::flat_hash_map<TypePtr, Element*, HashType, EqualType> wildcardIndex_;
  Element* getWildcard(const TypePtr& type) const;
  std::optional<Element*> tryGetOrCreateWildcard(const TypePtr& type);
  void addContainedTypesToFreshElement(
      Element* container_elem,
      const AliasTypeSet& mut_types);
  void pointUnionTypeElementToAllContainedTypes(
      Element* container_elem,
      const AliasTypeSet& mut_types);

  std::vector<Element*> getElements(at::ArrayRef<Value*> vs) const;
  bool mayAliasWildcard(const Value* v) const;
  bool mayAliasWildcard(const at::ArrayRef<Value*> vs) const;
  bool hasWriters(const at::ArrayRef<Value*>& values) const;

  // Cached mapping of type ptrs to their mutable types
  mutable ska::flat_hash_map<TypePtr, AliasTypeSet> mapped_mutable_types_;

  /**
   * State for tracking write info.
   */
  // Write registry where the analysis can record the writes as it sees them.
  // This information is later denormalized into various caches to improve
  // query efficiency.
  struct WriteRegistry;
  std::unique_ptr<WriteRegistry> writeRegistry_;

  // Map of nodes to the memory locations that they write to
  using TWriteIndex = ska::flat_hash_map<Node*, MemoryLocations>;
  std::optional<TWriteIndex> writeIndex_;
  // Collection of all memory locations that are written to.
  std::optional<MemoryLocations> writtenToLocationsIndex_;
  void buildWrittenToLocationsIndex();

  std::unordered_set<const Value*> wildcards_;

  std::string getElementName(const Element* e) const;

  friend void Lint(const AliasDb* db);
};

// Helper check that invariants over AliasDb are maintained.
// Useful if you are using the AliasDb mutation API and want to check you did
// the right thing.
TORCH_API void Lint(const AliasDb* db);

} // namespace torch::jit