1 #pragma once 2 3 #include <ATen/core/jit_type.h> 4 #include <c10/util/ArrayRef.h> 5 #include <c10/util/flat_hash_map.h> 6 #include <c10/util/sparse_bitset.h> 7 #include <torch/csrc/jit/ir/ir.h> 8 #include <torch/csrc/jit/ir/type_hashing.h> 9 #include <memory> 10 #include <optional> 11 #include <unordered_map> 12 #include <unordered_set> 13 #include <vector> 14 15 #include <torch/csrc/Export.h> 16 17 // Uses a compressed index representation for faster comparisons 18 typedef c10::SparseBitVector<256> MemoryLocations; 19 namespace torch { 20 namespace jit { 21 22 struct Value; 23 24 using AliasTypeSet = std::vector<TypePtr>; 25 26 // `Element` represents a vertex in the points-to graph. It represents 27 // anything that could have an aliasing relationship--mostly IR 28 // `Value`s, but also wildcards or the type inside a container (e.g. `T` 29 // in `List[T]`) 30 struct Element { 31 Element(const Value* value_, unsigned index_); 32 // wildcard constructor 33 explicit Element(unsigned index_); 34 35 // Index into the owning DAG's bit vector that represents this element. 36 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 37 unsigned index; 38 39 // All elements that this element *may* point to. It's possible to have 40 // multiple elements that you might point to due to control flow/complex ops 41 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 42 MemoryLocations pointsTo; 43 // Backreference for points-to. 44 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 45 MemoryLocations pointedFrom; 46 47 // Elements can contain other elements (e.g. List[Tensor]) 48 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 49 MemoryLocations containedElements; 50 51 // The values that this element corresponds to. May be empty if this element 52 // doesn't represent a first-class value. 53 // This is for debug information only. 54 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 55 std::unordered_set<const Value*> values; 56 57 private: 58 // Make `from` point at `to`. 59 void makePointerTo(Element* from, Element* to); 60 61 friend class MemoryDAG; 62 // We memoize the results of `getMemoryLocations` to speed up queries. 63 // A nullopt means that this cache is not yet populated. Since `MemoryDAG` is 64 // immutable, this cache should never need to be invalidated. 65 mutable std::optional<MemoryLocations> cachedMemoryLocations_; 66 67 mutable std::optional<MemoryLocations> cachedAllContainedMemoryLocations_; 68 }; 69 70 // class MemoryDAG 71 // 72 // This class tracks the "A points to B" graph for all values. It is used by 73 // AliasDb to provide a higher-level API. 74 // 75 // We maintain a DAG where: 76 // - Vertices (called "Elements") represent Values and 77 // other aliasing entities (e.g. the stuff inside a list) 78 // - Edges represent a "points-to" relationship. 79 // 80 // Leaves in this DAG are entities that don't point to anything, and thus 81 // correspond to unique "memory locations". 82 // 83 // So, by traversing the "points-to" graph to the leaves, you can determine 84 // which memory locations an element may point to. 85 class TORCH_API MemoryDAG { 86 public: MemoryDAG(std::vector<std::unique_ptr<Element>> indexToElementMap)87 explicit MemoryDAG(std::vector<std::unique_ptr<Element>> indexToElementMap) 88 : indexToElementMap_(std::move(indexToElementMap)) {} 89 // explicitly delete copy constructor because otherwise windows build is 90 // confused for an exported class see 91 // https://stackoverflow.com/a/51033485/105137 92 MemoryDAG(const MemoryDAG&) = delete; 93 MemoryDAG& operator=(const MemoryDAG&) = delete; 94 95 // Return the unique memory locations that `Element` might represent. 96 const MemoryLocations& getMemoryLocations(const Element* e) const; 97 98 // Do `a` and `b` potentially share a memory location? 99 bool mayAlias(const Element* a, const Element* b) const; 100 101 // Does `a` hold reference to any memory that is stored in `b`, or vice versa? 102 bool mayContainAlias(const Element* a, const Element* b) const; 103 104 bool mayContainAlias(const Element* a, const at::ArrayRef<Element*> b) const; 105 106 bool mayContainAlias( 107 const at::ArrayRef<Element*> a, 108 const at::ArrayRef<Element*> b) const; 109 110 // Converts from the compressed index representation 111 const Element* fromIndex(unsigned x) const; 112 Element* fromIndex(unsigned x); 113 void collectAllContainedMemoryLocations( 114 const Element* elem, 115 MemoryLocations& cont) const; 116 117 /** 118 * The following methods are special cases where we need to mutate the 119 * internals of MemoryDAG for efficiency reasons. Don't call them unless you 120 * know what you're doing! In particular, don't add new mutating methods 121 * without ensuring that you are maintaining cache consistency for memory 122 * locations. 123 */ 124 125 // Adding wildcards can trigger extremely expensive cache invalidations. This 126 // method adds them in a more efficient cache-aware way. 127 void setWildcards( 128 const std::unordered_set<const Value*>& wildcards, 129 const ska::flat_hash_map<const Value*, Element*>& elementMap, 130 const std::function<Element*(const Value*)>& getWildcardElement); 131 Element* unsafeMakeFreshValue(const Value* v); 132 133 private: 134 const MemoryLocations& getAllContainedMemoryLocations( 135 const Element* elem) const; 136 void collectAllContainedMemoryLocationsImpl( 137 const Element* elem, 138 MemoryLocations& cont) const; 139 std::vector<std::unique_ptr<Element>> indexToElementMap_; 140 }; 141 142 /** 143 * Helper to build up the points-to graph. 144 * 145 * We separate the "building" into a different class because it allows us to 146 * cache internally to MemoryDAG without worrying about how the DAG structure 147 * is mutated. 148 */ 149 class TORCH_API MemoryDAGBuilder { 150 public: 151 MemoryDAGBuilder() = default; 152 MemoryDAGBuilder(const MemoryDAGBuilder&) = delete; 153 MemoryDAGBuilder& operator=(const MemoryDAGBuilder&) = delete; 154 155 // Make `from` point at `to`. 156 void makePointerTo(Element* from, Element* to); 157 158 void addToContainedElements(Element* contained, Element* container); 159 createMemoryDAG()160 std::unique_ptr<MemoryDAG> createMemoryDAG() && { 161 return std::make_unique<MemoryDAG>(std::move(indexToElementMap_)); 162 } 163 164 // Make a fresh Element (i.e. an Element that doesn't point to anything) and 165 // return it. 166 Element* makeFreshValue(const Value* v); 167 168 friend MemoryDAG; 169 170 private: 171 // `MemoryDAGBuilder` builds up `indexToElementMap_`, then uses 172 // the map to construct the `MemoryDAG` 173 std::vector<std::unique_ptr<Element>> indexToElementMap_; 174 }; 175 } // namespace jit 176 } // namespace torch 177