xref: /aosp_15_r20/external/stg/equality_cache.h (revision 9e3b08ae94a55201065475453d799e8b1378bea6)
1 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
2 // -*- mode: C++ -*-
3 //
4 // Copyright 2022 Google LLC
5 //
6 // Licensed under the Apache License v2.0 with LLVM Exceptions (the
7 // "License"); you may not use this file except in compliance with the
8 // License.  You may obtain a copy of the License at
9 //
10 //     https://llvm.org/LICENSE.txt
11 //
12 // Unless required by applicable law or agreed to in writing, software
13 // distributed under the License is distributed on an "AS IS" BASIS,
14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 // See the License for the specific language governing permissions and
16 // limitations under the License.
17 //
18 // Author: Giuliano Procida
19 
20 #ifndef STG_EQUALITY_CACHE_H_
21 #define STG_EQUALITY_CACHE_H_
22 
23 #include <cstddef>
24 #include <optional>
25 #include <unordered_map>
26 #include <unordered_set>
27 #include <utility>
28 #include <vector>
29 
30 #include "graph.h"
31 #include "hashing.h"
32 #include "runtime.h"
33 
34 namespace stg {
35 
36 // Equality cache - for use with the Equals function object
37 //
38 // This supports many features, some of probably limited long-term utility.
39 //
40 // It caches equalities (symmetrically) using union-find with path halving and
41 // union by rank.
42 //
43 // It caches inequalities (symmetrically); the inequalities are updated as part
44 // of the union operation.
45 //
46 // Node hashes such as those generated by the Fingerprint function object may be
47 // supplied to avoid equality testing when hashes differ.
48 struct EqualityCache {
EqualityCacheEqualityCache49   EqualityCache(Runtime& runtime,
50                 const std::unordered_map<Id, HashValue>& hashes)
51       : hashes(hashes),
52         query_count(runtime, "cache.query_count"),
53         query_equal_ids(runtime, "cache.query_equal_ids"),
54         query_unequal_hashes(runtime, "cache.query_unequal_hashes"),
55         query_equal_representatives(runtime,
56                                     "cache.query_equal_representatives"),
57         query_inequality_found(runtime, "cache.query_inequality_found"),
58         query_not_found(runtime, "cache.query_not_found"),
59         find_halved(runtime, "cache.find_halved"),
60         union_known(runtime, "cache.union_known"),
61         union_rank_swap(runtime, "cache.union_rank_swap"),
62         union_rank_increase(runtime, "cache.union_rank_increase"),
63         union_rank_zero(runtime, "cache.union_rank_zero"),
64         union_unknown(runtime, "cache.union_unknown"),
65         disunion_known_hash(runtime, "cache.disunion_known_hash"),
66         disunion_known_inequality(runtime, "cache.disunion_known_inequality"),
67         disunion_unknown(runtime, "cache.disunion_unknown") {}
68 
QueryEqualityCache69   std::optional<bool> Query(const Pair& comparison) {
70     ++query_count;
71     const auto& [id1, id2] = comparison;
72     if (id1 == id2) {
73       ++query_equal_ids;
74       return std::make_optional(true);
75     }
76     if (DistinctHashes(id1, id2)) {
77       ++query_unequal_hashes;
78       return std::make_optional(false);
79     }
80     const Id fid1 = Find(id1);
81     const Id fid2 = Find(id2);
82     if (fid1 == fid2) {
83       ++query_equal_representatives;
84       return std::make_optional(true);
85     }
86     auto not_it = inequalities.find(fid1);
87     if (not_it != inequalities.end()) {
88       auto not_it2 = not_it->second.find(fid2);
89       if (not_it2 != not_it->second.end()) {
90         ++query_inequality_found;
91         return std::make_optional(false);
92       }
93     }
94     ++query_not_found;
95     return std::nullopt;
96   }
97 
AllSameEqualityCache98   void AllSame(const std::vector<Pair>& comparisons) {
99     for (const auto& [id1, id2] : comparisons) {
100       Union(id1, id2);
101     }
102   }
103 
AllDifferentEqualityCache104   void AllDifferent(const std::vector<Pair>& comparisons) {
105     for (const auto& [id1, id2] : comparisons) {
106       Disunion(id1, id2);
107     }
108   }
109 
DistinctHashesEqualityCache110   bool DistinctHashes(Id id1, Id id2) {
111     const auto it1 = hashes.find(id1);
112     const auto it2 = hashes.find(id2);
113     return it1 != hashes.end() && it2 != hashes.end()
114         && it1->second != it2->second;
115   }
116 
FindEqualityCache117   Id Find(Id id) {
118     // path halving
119     while (true) {
120       auto it = mapping.find(id);
121       if (it == mapping.end()) {
122         return id;
123       }
124       auto& parent = it->second;
125       auto parent_it = mapping.find(parent);
126       if (parent_it == mapping.end()) {
127         return parent;
128       }
129       auto parent_parent = parent_it->second;
130       id = parent = parent_parent;
131       ++find_halved;
132     }
133   }
134 
GetRankEqualityCache135   size_t GetRank(Id id) {
136     auto it = rank.find(id);
137     return it == rank.end() ? 0 : it->second;
138   }
139 
SetRankEqualityCache140   void SetRank(Id id, size_t r) {
141     if (r) {
142       rank[id] = r;
143     } else {
144       rank.erase(id);
145     }
146   }
147 
UnionEqualityCache148   void Union(Id id1, Id id2) {
149     Check(!DistinctHashes(id1, id2)) << "union with distinct hashes";
150     Id fid1 = Find(id1);
151     Id fid2 = Find(id2);
152     if (fid1 == fid2) {
153       ++union_known;
154       return;
155     }
156     size_t rank1 = GetRank(fid1);
157     size_t rank2 = GetRank(fid2);
158     if (rank1 > rank2) {
159       std::swap(fid1, fid2);
160       std::swap(rank1, rank2);
161       ++union_rank_swap;
162     }
163     // rank1 <= rank2
164     if (rank1 == rank2) {
165       SetRank(fid2, rank2 + 1);
166       ++union_rank_increase;
167     }
168     if (rank1) {
169       SetRank(fid1, 0);
170       ++union_rank_zero;
171     }
172     mapping.insert({fid1, fid2});
173     ++union_unknown;
174 
175     // move inequalities from fid1 to fid2
176     auto not_it = inequalities.find(fid1);
177     if (not_it != inequalities.end()) {
178       auto& source = not_it->second;
179       auto& target = inequalities[fid2];
180       for (auto fid : source) {
181         Check(fid != fid2) << "union of unequal";
182         target.insert(fid);
183         auto& target2 = inequalities[fid];
184         target2.erase(fid1);
185         target2.insert(fid2);
186       }
187     }
188   }
189 
DisunionEqualityCache190   void Disunion(Id id1, Id id2) {
191     if (DistinctHashes(id1, id2)) {
192       ++disunion_known_hash;
193       return;
194     }
195     const Id fid1 = Find(id1);
196     const Id fid2 = Find(id2);
197     Check(fid1 != fid2) << "disunion of equal";
198     if (inequalities[fid1].insert(fid2).second) {
199       inequalities[fid2].insert(fid1);
200       ++disunion_unknown;
201     } else {
202       ++disunion_known_inequality;
203     }
204   }
205 
206   const std::unordered_map<Id, HashValue>& hashes;
207   std::unordered_map<Id, Id> mapping;
208   std::unordered_map<Id, size_t> rank;
209   std::unordered_map<Id, std::unordered_set<Id>> inequalities;
210 
211   Counter query_count;
212   Counter query_equal_ids;
213   Counter query_unequal_hashes;
214   Counter query_equal_representatives;
215   Counter query_inequality_found;
216   Counter query_not_found;
217   Counter find_halved;
218   Counter union_known;
219   Counter union_rank_swap;
220   Counter union_rank_increase;
221   Counter union_rank_zero;
222   Counter union_unknown;
223   Counter disunion_known_hash;
224   Counter disunion_known_inequality;
225   Counter disunion_unknown;
226 };
227 
228 struct SimpleEqualityCache {
SimpleEqualityCacheSimpleEqualityCache229   explicit SimpleEqualityCache(Runtime& runtime)
230       : query_count(runtime, "simple_cache.query_count"),
231         query_equal_ids(runtime, "simple_cache.query_equal_ids"),
232         query_known_equality(runtime, "simple_cache.query_known_equality"),
233         known_equality_inserts(runtime, "simple_cache.known_equality_inserts") {
234   }
235 
QuerySimpleEqualityCache236   std::optional<bool> Query(const Pair& comparison) {
237     ++query_count;
238     const auto& [id1, id2] = comparison;
239     if (id1 == id2) {
240       ++query_equal_ids;
241       return {true};
242     }
243     if (known_equalities.contains(comparison)) {
244       ++query_known_equality;
245       return {true};
246     }
247     return std::nullopt;
248   }
249 
AllSameSimpleEqualityCache250   void AllSame(const std::vector<Pair>& comparisons) {
251     for (const auto& comparison : comparisons) {
252       ++known_equality_inserts;
253       known_equalities.insert(comparison);
254     }
255   }
256 
AllDifferentSimpleEqualityCache257   void AllDifferent(const std::vector<Pair>&) {}
258 
259   std::unordered_set<Pair> known_equalities;
260 
261   Counter query_count;
262   Counter query_equal_ids;
263   Counter query_known_equality;
264   Counter known_equality_inserts;
265 };
266 
267 }  // namespace stg
268 
269 #endif  // STG_EQUALITY_CACHE_H_
270