1 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 2 // -*- mode: C++ -*- 3 // 4 // Copyright 2022 Google LLC 5 // 6 // Licensed under the Apache License v2.0 with LLVM Exceptions (the 7 // "License"); you may not use this file except in compliance with the 8 // License. You may obtain a copy of the License at 9 // 10 // https://llvm.org/LICENSE.txt 11 // 12 // Unless required by applicable law or agreed to in writing, software 13 // distributed under the License is distributed on an "AS IS" BASIS, 14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 // See the License for the specific language governing permissions and 16 // limitations under the License. 17 // 18 // Author: Giuliano Procida 19 20 #ifndef STG_EQUALITY_CACHE_H_ 21 #define STG_EQUALITY_CACHE_H_ 22 23 #include <cstddef> 24 #include <optional> 25 #include <unordered_map> 26 #include <unordered_set> 27 #include <utility> 28 #include <vector> 29 30 #include "graph.h" 31 #include "hashing.h" 32 #include "runtime.h" 33 34 namespace stg { 35 36 // Equality cache - for use with the Equals function object 37 // 38 // This supports many features, some of probably limited long-term utility. 39 // 40 // It caches equalities (symmetrically) using union-find with path halving and 41 // union by rank. 42 // 43 // It caches inequalities (symmetrically); the inequalities are updated as part 44 // of the union operation. 45 // 46 // Node hashes such as those generated by the Fingerprint function object may be 47 // supplied to avoid equality testing when hashes differ. 48 struct EqualityCache { EqualityCacheEqualityCache49 EqualityCache(Runtime& runtime, 50 const std::unordered_map<Id, HashValue>& hashes) 51 : hashes(hashes), 52 query_count(runtime, "cache.query_count"), 53 query_equal_ids(runtime, "cache.query_equal_ids"), 54 query_unequal_hashes(runtime, "cache.query_unequal_hashes"), 55 query_equal_representatives(runtime, 56 "cache.query_equal_representatives"), 57 query_inequality_found(runtime, "cache.query_inequality_found"), 58 query_not_found(runtime, "cache.query_not_found"), 59 find_halved(runtime, "cache.find_halved"), 60 union_known(runtime, "cache.union_known"), 61 union_rank_swap(runtime, "cache.union_rank_swap"), 62 union_rank_increase(runtime, "cache.union_rank_increase"), 63 union_rank_zero(runtime, "cache.union_rank_zero"), 64 union_unknown(runtime, "cache.union_unknown"), 65 disunion_known_hash(runtime, "cache.disunion_known_hash"), 66 disunion_known_inequality(runtime, "cache.disunion_known_inequality"), 67 disunion_unknown(runtime, "cache.disunion_unknown") {} 68 QueryEqualityCache69 std::optional<bool> Query(const Pair& comparison) { 70 ++query_count; 71 const auto& [id1, id2] = comparison; 72 if (id1 == id2) { 73 ++query_equal_ids; 74 return std::make_optional(true); 75 } 76 if (DistinctHashes(id1, id2)) { 77 ++query_unequal_hashes; 78 return std::make_optional(false); 79 } 80 const Id fid1 = Find(id1); 81 const Id fid2 = Find(id2); 82 if (fid1 == fid2) { 83 ++query_equal_representatives; 84 return std::make_optional(true); 85 } 86 auto not_it = inequalities.find(fid1); 87 if (not_it != inequalities.end()) { 88 auto not_it2 = not_it->second.find(fid2); 89 if (not_it2 != not_it->second.end()) { 90 ++query_inequality_found; 91 return std::make_optional(false); 92 } 93 } 94 ++query_not_found; 95 return std::nullopt; 96 } 97 AllSameEqualityCache98 void AllSame(const std::vector<Pair>& comparisons) { 99 for (const auto& [id1, id2] : comparisons) { 100 Union(id1, id2); 101 } 102 } 103 AllDifferentEqualityCache104 void AllDifferent(const std::vector<Pair>& comparisons) { 105 for (const auto& [id1, id2] : comparisons) { 106 Disunion(id1, id2); 107 } 108 } 109 DistinctHashesEqualityCache110 bool DistinctHashes(Id id1, Id id2) { 111 const auto it1 = hashes.find(id1); 112 const auto it2 = hashes.find(id2); 113 return it1 != hashes.end() && it2 != hashes.end() 114 && it1->second != it2->second; 115 } 116 FindEqualityCache117 Id Find(Id id) { 118 // path halving 119 while (true) { 120 auto it = mapping.find(id); 121 if (it == mapping.end()) { 122 return id; 123 } 124 auto& parent = it->second; 125 auto parent_it = mapping.find(parent); 126 if (parent_it == mapping.end()) { 127 return parent; 128 } 129 auto parent_parent = parent_it->second; 130 id = parent = parent_parent; 131 ++find_halved; 132 } 133 } 134 GetRankEqualityCache135 size_t GetRank(Id id) { 136 auto it = rank.find(id); 137 return it == rank.end() ? 0 : it->second; 138 } 139 SetRankEqualityCache140 void SetRank(Id id, size_t r) { 141 if (r) { 142 rank[id] = r; 143 } else { 144 rank.erase(id); 145 } 146 } 147 UnionEqualityCache148 void Union(Id id1, Id id2) { 149 Check(!DistinctHashes(id1, id2)) << "union with distinct hashes"; 150 Id fid1 = Find(id1); 151 Id fid2 = Find(id2); 152 if (fid1 == fid2) { 153 ++union_known; 154 return; 155 } 156 size_t rank1 = GetRank(fid1); 157 size_t rank2 = GetRank(fid2); 158 if (rank1 > rank2) { 159 std::swap(fid1, fid2); 160 std::swap(rank1, rank2); 161 ++union_rank_swap; 162 } 163 // rank1 <= rank2 164 if (rank1 == rank2) { 165 SetRank(fid2, rank2 + 1); 166 ++union_rank_increase; 167 } 168 if (rank1) { 169 SetRank(fid1, 0); 170 ++union_rank_zero; 171 } 172 mapping.insert({fid1, fid2}); 173 ++union_unknown; 174 175 // move inequalities from fid1 to fid2 176 auto not_it = inequalities.find(fid1); 177 if (not_it != inequalities.end()) { 178 auto& source = not_it->second; 179 auto& target = inequalities[fid2]; 180 for (auto fid : source) { 181 Check(fid != fid2) << "union of unequal"; 182 target.insert(fid); 183 auto& target2 = inequalities[fid]; 184 target2.erase(fid1); 185 target2.insert(fid2); 186 } 187 } 188 } 189 DisunionEqualityCache190 void Disunion(Id id1, Id id2) { 191 if (DistinctHashes(id1, id2)) { 192 ++disunion_known_hash; 193 return; 194 } 195 const Id fid1 = Find(id1); 196 const Id fid2 = Find(id2); 197 Check(fid1 != fid2) << "disunion of equal"; 198 if (inequalities[fid1].insert(fid2).second) { 199 inequalities[fid2].insert(fid1); 200 ++disunion_unknown; 201 } else { 202 ++disunion_known_inequality; 203 } 204 } 205 206 const std::unordered_map<Id, HashValue>& hashes; 207 std::unordered_map<Id, Id> mapping; 208 std::unordered_map<Id, size_t> rank; 209 std::unordered_map<Id, std::unordered_set<Id>> inequalities; 210 211 Counter query_count; 212 Counter query_equal_ids; 213 Counter query_unequal_hashes; 214 Counter query_equal_representatives; 215 Counter query_inequality_found; 216 Counter query_not_found; 217 Counter find_halved; 218 Counter union_known; 219 Counter union_rank_swap; 220 Counter union_rank_increase; 221 Counter union_rank_zero; 222 Counter union_unknown; 223 Counter disunion_known_hash; 224 Counter disunion_known_inequality; 225 Counter disunion_unknown; 226 }; 227 228 struct SimpleEqualityCache { SimpleEqualityCacheSimpleEqualityCache229 explicit SimpleEqualityCache(Runtime& runtime) 230 : query_count(runtime, "simple_cache.query_count"), 231 query_equal_ids(runtime, "simple_cache.query_equal_ids"), 232 query_known_equality(runtime, "simple_cache.query_known_equality"), 233 known_equality_inserts(runtime, "simple_cache.known_equality_inserts") { 234 } 235 QuerySimpleEqualityCache236 std::optional<bool> Query(const Pair& comparison) { 237 ++query_count; 238 const auto& [id1, id2] = comparison; 239 if (id1 == id2) { 240 ++query_equal_ids; 241 return {true}; 242 } 243 if (known_equalities.contains(comparison)) { 244 ++query_known_equality; 245 return {true}; 246 } 247 return std::nullopt; 248 } 249 AllSameSimpleEqualityCache250 void AllSame(const std::vector<Pair>& comparisons) { 251 for (const auto& comparison : comparisons) { 252 ++known_equality_inserts; 253 known_equalities.insert(comparison); 254 } 255 } 256 AllDifferentSimpleEqualityCache257 void AllDifferent(const std::vector<Pair>&) {} 258 259 std::unordered_set<Pair> known_equalities; 260 261 Counter query_count; 262 Counter query_equal_ids; 263 Counter query_known_equality; 264 Counter known_equality_inserts; 265 }; 266 267 } // namespace stg 268 269 #endif // STG_EQUALITY_CACHE_H_ 270