xref: /aosp_15_r20/external/stg/equality_cache.h (revision 9e3b08ae94a55201065475453d799e8b1378bea6)
1*9e3b08aeSAndroid Build Coastguard Worker // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
2*9e3b08aeSAndroid Build Coastguard Worker // -*- mode: C++ -*-
3*9e3b08aeSAndroid Build Coastguard Worker //
4*9e3b08aeSAndroid Build Coastguard Worker // Copyright 2022 Google LLC
5*9e3b08aeSAndroid Build Coastguard Worker //
6*9e3b08aeSAndroid Build Coastguard Worker // Licensed under the Apache License v2.0 with LLVM Exceptions (the
7*9e3b08aeSAndroid Build Coastguard Worker // "License"); you may not use this file except in compliance with the
8*9e3b08aeSAndroid Build Coastguard Worker // License.  You may obtain a copy of the License at
9*9e3b08aeSAndroid Build Coastguard Worker //
10*9e3b08aeSAndroid Build Coastguard Worker //     https://llvm.org/LICENSE.txt
11*9e3b08aeSAndroid Build Coastguard Worker //
12*9e3b08aeSAndroid Build Coastguard Worker // Unless required by applicable law or agreed to in writing, software
13*9e3b08aeSAndroid Build Coastguard Worker // distributed under the License is distributed on an "AS IS" BASIS,
14*9e3b08aeSAndroid Build Coastguard Worker // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15*9e3b08aeSAndroid Build Coastguard Worker // See the License for the specific language governing permissions and
16*9e3b08aeSAndroid Build Coastguard Worker // limitations under the License.
17*9e3b08aeSAndroid Build Coastguard Worker //
18*9e3b08aeSAndroid Build Coastguard Worker // Author: Giuliano Procida
19*9e3b08aeSAndroid Build Coastguard Worker 
20*9e3b08aeSAndroid Build Coastguard Worker #ifndef STG_EQUALITY_CACHE_H_
21*9e3b08aeSAndroid Build Coastguard Worker #define STG_EQUALITY_CACHE_H_
22*9e3b08aeSAndroid Build Coastguard Worker 
23*9e3b08aeSAndroid Build Coastguard Worker #include <cstddef>
24*9e3b08aeSAndroid Build Coastguard Worker #include <optional>
25*9e3b08aeSAndroid Build Coastguard Worker #include <unordered_map>
26*9e3b08aeSAndroid Build Coastguard Worker #include <unordered_set>
27*9e3b08aeSAndroid Build Coastguard Worker #include <utility>
28*9e3b08aeSAndroid Build Coastguard Worker #include <vector>
29*9e3b08aeSAndroid Build Coastguard Worker 
30*9e3b08aeSAndroid Build Coastguard Worker #include "graph.h"
31*9e3b08aeSAndroid Build Coastguard Worker #include "hashing.h"
32*9e3b08aeSAndroid Build Coastguard Worker #include "runtime.h"
33*9e3b08aeSAndroid Build Coastguard Worker 
34*9e3b08aeSAndroid Build Coastguard Worker namespace stg {
35*9e3b08aeSAndroid Build Coastguard Worker 
// Equality cache - for use with the Equals function object
//
// This supports many features, some of which are probably of limited long-term
// utility.
//
// It caches equalities (symmetrically) using union-find with path halving and
// union by rank.
//
// It caches inequalities (symmetrically) between class representatives; when
// two classes are merged, the inequalities of the absorbed representative are
// moved to the surviving one as part of the union operation.
//
// Node hashes, such as those generated by the Fingerprint function object, may
// be supplied to avoid equality testing when hashes differ.
48*9e3b08aeSAndroid Build Coastguard Worker struct EqualityCache {
EqualityCacheEqualityCache49*9e3b08aeSAndroid Build Coastguard Worker   EqualityCache(Runtime& runtime,
50*9e3b08aeSAndroid Build Coastguard Worker                 const std::unordered_map<Id, HashValue>& hashes)
51*9e3b08aeSAndroid Build Coastguard Worker       : hashes(hashes),
52*9e3b08aeSAndroid Build Coastguard Worker         query_count(runtime, "cache.query_count"),
53*9e3b08aeSAndroid Build Coastguard Worker         query_equal_ids(runtime, "cache.query_equal_ids"),
54*9e3b08aeSAndroid Build Coastguard Worker         query_unequal_hashes(runtime, "cache.query_unequal_hashes"),
55*9e3b08aeSAndroid Build Coastguard Worker         query_equal_representatives(runtime,
56*9e3b08aeSAndroid Build Coastguard Worker                                     "cache.query_equal_representatives"),
57*9e3b08aeSAndroid Build Coastguard Worker         query_inequality_found(runtime, "cache.query_inequality_found"),
58*9e3b08aeSAndroid Build Coastguard Worker         query_not_found(runtime, "cache.query_not_found"),
59*9e3b08aeSAndroid Build Coastguard Worker         find_halved(runtime, "cache.find_halved"),
60*9e3b08aeSAndroid Build Coastguard Worker         union_known(runtime, "cache.union_known"),
61*9e3b08aeSAndroid Build Coastguard Worker         union_rank_swap(runtime, "cache.union_rank_swap"),
62*9e3b08aeSAndroid Build Coastguard Worker         union_rank_increase(runtime, "cache.union_rank_increase"),
63*9e3b08aeSAndroid Build Coastguard Worker         union_rank_zero(runtime, "cache.union_rank_zero"),
64*9e3b08aeSAndroid Build Coastguard Worker         union_unknown(runtime, "cache.union_unknown"),
65*9e3b08aeSAndroid Build Coastguard Worker         disunion_known_hash(runtime, "cache.disunion_known_hash"),
66*9e3b08aeSAndroid Build Coastguard Worker         disunion_known_inequality(runtime, "cache.disunion_known_inequality"),
67*9e3b08aeSAndroid Build Coastguard Worker         disunion_unknown(runtime, "cache.disunion_unknown") {}
68*9e3b08aeSAndroid Build Coastguard Worker 
QueryEqualityCache69*9e3b08aeSAndroid Build Coastguard Worker   std::optional<bool> Query(const Pair& comparison) {
70*9e3b08aeSAndroid Build Coastguard Worker     ++query_count;
71*9e3b08aeSAndroid Build Coastguard Worker     const auto& [id1, id2] = comparison;
72*9e3b08aeSAndroid Build Coastguard Worker     if (id1 == id2) {
73*9e3b08aeSAndroid Build Coastguard Worker       ++query_equal_ids;
74*9e3b08aeSAndroid Build Coastguard Worker       return std::make_optional(true);
75*9e3b08aeSAndroid Build Coastguard Worker     }
76*9e3b08aeSAndroid Build Coastguard Worker     if (DistinctHashes(id1, id2)) {
77*9e3b08aeSAndroid Build Coastguard Worker       ++query_unequal_hashes;
78*9e3b08aeSAndroid Build Coastguard Worker       return std::make_optional(false);
79*9e3b08aeSAndroid Build Coastguard Worker     }
80*9e3b08aeSAndroid Build Coastguard Worker     const Id fid1 = Find(id1);
81*9e3b08aeSAndroid Build Coastguard Worker     const Id fid2 = Find(id2);
82*9e3b08aeSAndroid Build Coastguard Worker     if (fid1 == fid2) {
83*9e3b08aeSAndroid Build Coastguard Worker       ++query_equal_representatives;
84*9e3b08aeSAndroid Build Coastguard Worker       return std::make_optional(true);
85*9e3b08aeSAndroid Build Coastguard Worker     }
86*9e3b08aeSAndroid Build Coastguard Worker     auto not_it = inequalities.find(fid1);
87*9e3b08aeSAndroid Build Coastguard Worker     if (not_it != inequalities.end()) {
88*9e3b08aeSAndroid Build Coastguard Worker       auto not_it2 = not_it->second.find(fid2);
89*9e3b08aeSAndroid Build Coastguard Worker       if (not_it2 != not_it->second.end()) {
90*9e3b08aeSAndroid Build Coastguard Worker         ++query_inequality_found;
91*9e3b08aeSAndroid Build Coastguard Worker         return std::make_optional(false);
92*9e3b08aeSAndroid Build Coastguard Worker       }
93*9e3b08aeSAndroid Build Coastguard Worker     }
94*9e3b08aeSAndroid Build Coastguard Worker     ++query_not_found;
95*9e3b08aeSAndroid Build Coastguard Worker     return std::nullopt;
96*9e3b08aeSAndroid Build Coastguard Worker   }
97*9e3b08aeSAndroid Build Coastguard Worker 
AllSameEqualityCache98*9e3b08aeSAndroid Build Coastguard Worker   void AllSame(const std::vector<Pair>& comparisons) {
99*9e3b08aeSAndroid Build Coastguard Worker     for (const auto& [id1, id2] : comparisons) {
100*9e3b08aeSAndroid Build Coastguard Worker       Union(id1, id2);
101*9e3b08aeSAndroid Build Coastguard Worker     }
102*9e3b08aeSAndroid Build Coastguard Worker   }
103*9e3b08aeSAndroid Build Coastguard Worker 
AllDifferentEqualityCache104*9e3b08aeSAndroid Build Coastguard Worker   void AllDifferent(const std::vector<Pair>& comparisons) {
105*9e3b08aeSAndroid Build Coastguard Worker     for (const auto& [id1, id2] : comparisons) {
106*9e3b08aeSAndroid Build Coastguard Worker       Disunion(id1, id2);
107*9e3b08aeSAndroid Build Coastguard Worker     }
108*9e3b08aeSAndroid Build Coastguard Worker   }
109*9e3b08aeSAndroid Build Coastguard Worker 
DistinctHashesEqualityCache110*9e3b08aeSAndroid Build Coastguard Worker   bool DistinctHashes(Id id1, Id id2) {
111*9e3b08aeSAndroid Build Coastguard Worker     const auto it1 = hashes.find(id1);
112*9e3b08aeSAndroid Build Coastguard Worker     const auto it2 = hashes.find(id2);
113*9e3b08aeSAndroid Build Coastguard Worker     return it1 != hashes.end() && it2 != hashes.end()
114*9e3b08aeSAndroid Build Coastguard Worker         && it1->second != it2->second;
115*9e3b08aeSAndroid Build Coastguard Worker   }
116*9e3b08aeSAndroid Build Coastguard Worker 
FindEqualityCache117*9e3b08aeSAndroid Build Coastguard Worker   Id Find(Id id) {
118*9e3b08aeSAndroid Build Coastguard Worker     // path halving
119*9e3b08aeSAndroid Build Coastguard Worker     while (true) {
120*9e3b08aeSAndroid Build Coastguard Worker       auto it = mapping.find(id);
121*9e3b08aeSAndroid Build Coastguard Worker       if (it == mapping.end()) {
122*9e3b08aeSAndroid Build Coastguard Worker         return id;
123*9e3b08aeSAndroid Build Coastguard Worker       }
124*9e3b08aeSAndroid Build Coastguard Worker       auto& parent = it->second;
125*9e3b08aeSAndroid Build Coastguard Worker       auto parent_it = mapping.find(parent);
126*9e3b08aeSAndroid Build Coastguard Worker       if (parent_it == mapping.end()) {
127*9e3b08aeSAndroid Build Coastguard Worker         return parent;
128*9e3b08aeSAndroid Build Coastguard Worker       }
129*9e3b08aeSAndroid Build Coastguard Worker       auto parent_parent = parent_it->second;
130*9e3b08aeSAndroid Build Coastguard Worker       id = parent = parent_parent;
131*9e3b08aeSAndroid Build Coastguard Worker       ++find_halved;
132*9e3b08aeSAndroid Build Coastguard Worker     }
133*9e3b08aeSAndroid Build Coastguard Worker   }
134*9e3b08aeSAndroid Build Coastguard Worker 
GetRankEqualityCache135*9e3b08aeSAndroid Build Coastguard Worker   size_t GetRank(Id id) {
136*9e3b08aeSAndroid Build Coastguard Worker     auto it = rank.find(id);
137*9e3b08aeSAndroid Build Coastguard Worker     return it == rank.end() ? 0 : it->second;
138*9e3b08aeSAndroid Build Coastguard Worker   }
139*9e3b08aeSAndroid Build Coastguard Worker 
SetRankEqualityCache140*9e3b08aeSAndroid Build Coastguard Worker   void SetRank(Id id, size_t r) {
141*9e3b08aeSAndroid Build Coastguard Worker     if (r) {
142*9e3b08aeSAndroid Build Coastguard Worker       rank[id] = r;
143*9e3b08aeSAndroid Build Coastguard Worker     } else {
144*9e3b08aeSAndroid Build Coastguard Worker       rank.erase(id);
145*9e3b08aeSAndroid Build Coastguard Worker     }
146*9e3b08aeSAndroid Build Coastguard Worker   }
147*9e3b08aeSAndroid Build Coastguard Worker 
UnionEqualityCache148*9e3b08aeSAndroid Build Coastguard Worker   void Union(Id id1, Id id2) {
149*9e3b08aeSAndroid Build Coastguard Worker     Check(!DistinctHashes(id1, id2)) << "union with distinct hashes";
150*9e3b08aeSAndroid Build Coastguard Worker     Id fid1 = Find(id1);
151*9e3b08aeSAndroid Build Coastguard Worker     Id fid2 = Find(id2);
152*9e3b08aeSAndroid Build Coastguard Worker     if (fid1 == fid2) {
153*9e3b08aeSAndroid Build Coastguard Worker       ++union_known;
154*9e3b08aeSAndroid Build Coastguard Worker       return;
155*9e3b08aeSAndroid Build Coastguard Worker     }
156*9e3b08aeSAndroid Build Coastguard Worker     size_t rank1 = GetRank(fid1);
157*9e3b08aeSAndroid Build Coastguard Worker     size_t rank2 = GetRank(fid2);
158*9e3b08aeSAndroid Build Coastguard Worker     if (rank1 > rank2) {
159*9e3b08aeSAndroid Build Coastguard Worker       std::swap(fid1, fid2);
160*9e3b08aeSAndroid Build Coastguard Worker       std::swap(rank1, rank2);
161*9e3b08aeSAndroid Build Coastguard Worker       ++union_rank_swap;
162*9e3b08aeSAndroid Build Coastguard Worker     }
163*9e3b08aeSAndroid Build Coastguard Worker     // rank1 <= rank2
164*9e3b08aeSAndroid Build Coastguard Worker     if (rank1 == rank2) {
165*9e3b08aeSAndroid Build Coastguard Worker       SetRank(fid2, rank2 + 1);
166*9e3b08aeSAndroid Build Coastguard Worker       ++union_rank_increase;
167*9e3b08aeSAndroid Build Coastguard Worker     }
168*9e3b08aeSAndroid Build Coastguard Worker     if (rank1) {
169*9e3b08aeSAndroid Build Coastguard Worker       SetRank(fid1, 0);
170*9e3b08aeSAndroid Build Coastguard Worker       ++union_rank_zero;
171*9e3b08aeSAndroid Build Coastguard Worker     }
172*9e3b08aeSAndroid Build Coastguard Worker     mapping.insert({fid1, fid2});
173*9e3b08aeSAndroid Build Coastguard Worker     ++union_unknown;
174*9e3b08aeSAndroid Build Coastguard Worker 
175*9e3b08aeSAndroid Build Coastguard Worker     // move inequalities from fid1 to fid2
176*9e3b08aeSAndroid Build Coastguard Worker     auto not_it = inequalities.find(fid1);
177*9e3b08aeSAndroid Build Coastguard Worker     if (not_it != inequalities.end()) {
178*9e3b08aeSAndroid Build Coastguard Worker       auto& source = not_it->second;
179*9e3b08aeSAndroid Build Coastguard Worker       auto& target = inequalities[fid2];
180*9e3b08aeSAndroid Build Coastguard Worker       for (auto fid : source) {
181*9e3b08aeSAndroid Build Coastguard Worker         Check(fid != fid2) << "union of unequal";
182*9e3b08aeSAndroid Build Coastguard Worker         target.insert(fid);
183*9e3b08aeSAndroid Build Coastguard Worker         auto& target2 = inequalities[fid];
184*9e3b08aeSAndroid Build Coastguard Worker         target2.erase(fid1);
185*9e3b08aeSAndroid Build Coastguard Worker         target2.insert(fid2);
186*9e3b08aeSAndroid Build Coastguard Worker       }
187*9e3b08aeSAndroid Build Coastguard Worker     }
188*9e3b08aeSAndroid Build Coastguard Worker   }
189*9e3b08aeSAndroid Build Coastguard Worker 
DisunionEqualityCache190*9e3b08aeSAndroid Build Coastguard Worker   void Disunion(Id id1, Id id2) {
191*9e3b08aeSAndroid Build Coastguard Worker     if (DistinctHashes(id1, id2)) {
192*9e3b08aeSAndroid Build Coastguard Worker       ++disunion_known_hash;
193*9e3b08aeSAndroid Build Coastguard Worker       return;
194*9e3b08aeSAndroid Build Coastguard Worker     }
195*9e3b08aeSAndroid Build Coastguard Worker     const Id fid1 = Find(id1);
196*9e3b08aeSAndroid Build Coastguard Worker     const Id fid2 = Find(id2);
197*9e3b08aeSAndroid Build Coastguard Worker     Check(fid1 != fid2) << "disunion of equal";
198*9e3b08aeSAndroid Build Coastguard Worker     if (inequalities[fid1].insert(fid2).second) {
199*9e3b08aeSAndroid Build Coastguard Worker       inequalities[fid2].insert(fid1);
200*9e3b08aeSAndroid Build Coastguard Worker       ++disunion_unknown;
201*9e3b08aeSAndroid Build Coastguard Worker     } else {
202*9e3b08aeSAndroid Build Coastguard Worker       ++disunion_known_inequality;
203*9e3b08aeSAndroid Build Coastguard Worker     }
204*9e3b08aeSAndroid Build Coastguard Worker   }
205*9e3b08aeSAndroid Build Coastguard Worker 
206*9e3b08aeSAndroid Build Coastguard Worker   const std::unordered_map<Id, HashValue>& hashes;
207*9e3b08aeSAndroid Build Coastguard Worker   std::unordered_map<Id, Id> mapping;
208*9e3b08aeSAndroid Build Coastguard Worker   std::unordered_map<Id, size_t> rank;
209*9e3b08aeSAndroid Build Coastguard Worker   std::unordered_map<Id, std::unordered_set<Id>> inequalities;
210*9e3b08aeSAndroid Build Coastguard Worker 
211*9e3b08aeSAndroid Build Coastguard Worker   Counter query_count;
212*9e3b08aeSAndroid Build Coastguard Worker   Counter query_equal_ids;
213*9e3b08aeSAndroid Build Coastguard Worker   Counter query_unequal_hashes;
214*9e3b08aeSAndroid Build Coastguard Worker   Counter query_equal_representatives;
215*9e3b08aeSAndroid Build Coastguard Worker   Counter query_inequality_found;
216*9e3b08aeSAndroid Build Coastguard Worker   Counter query_not_found;
217*9e3b08aeSAndroid Build Coastguard Worker   Counter find_halved;
218*9e3b08aeSAndroid Build Coastguard Worker   Counter union_known;
219*9e3b08aeSAndroid Build Coastguard Worker   Counter union_rank_swap;
220*9e3b08aeSAndroid Build Coastguard Worker   Counter union_rank_increase;
221*9e3b08aeSAndroid Build Coastguard Worker   Counter union_rank_zero;
222*9e3b08aeSAndroid Build Coastguard Worker   Counter union_unknown;
223*9e3b08aeSAndroid Build Coastguard Worker   Counter disunion_known_hash;
224*9e3b08aeSAndroid Build Coastguard Worker   Counter disunion_known_inequality;
225*9e3b08aeSAndroid Build Coastguard Worker   Counter disunion_unknown;
226*9e3b08aeSAndroid Build Coastguard Worker };
227*9e3b08aeSAndroid Build Coastguard Worker 
228*9e3b08aeSAndroid Build Coastguard Worker struct SimpleEqualityCache {
SimpleEqualityCacheSimpleEqualityCache229*9e3b08aeSAndroid Build Coastguard Worker   explicit SimpleEqualityCache(Runtime& runtime)
230*9e3b08aeSAndroid Build Coastguard Worker       : query_count(runtime, "simple_cache.query_count"),
231*9e3b08aeSAndroid Build Coastguard Worker         query_equal_ids(runtime, "simple_cache.query_equal_ids"),
232*9e3b08aeSAndroid Build Coastguard Worker         query_known_equality(runtime, "simple_cache.query_known_equality"),
233*9e3b08aeSAndroid Build Coastguard Worker         known_equality_inserts(runtime, "simple_cache.known_equality_inserts") {
234*9e3b08aeSAndroid Build Coastguard Worker   }
235*9e3b08aeSAndroid Build Coastguard Worker 
QuerySimpleEqualityCache236*9e3b08aeSAndroid Build Coastguard Worker   std::optional<bool> Query(const Pair& comparison) {
237*9e3b08aeSAndroid Build Coastguard Worker     ++query_count;
238*9e3b08aeSAndroid Build Coastguard Worker     const auto& [id1, id2] = comparison;
239*9e3b08aeSAndroid Build Coastguard Worker     if (id1 == id2) {
240*9e3b08aeSAndroid Build Coastguard Worker       ++query_equal_ids;
241*9e3b08aeSAndroid Build Coastguard Worker       return {true};
242*9e3b08aeSAndroid Build Coastguard Worker     }
243*9e3b08aeSAndroid Build Coastguard Worker     if (known_equalities.contains(comparison)) {
244*9e3b08aeSAndroid Build Coastguard Worker       ++query_known_equality;
245*9e3b08aeSAndroid Build Coastguard Worker       return {true};
246*9e3b08aeSAndroid Build Coastguard Worker     }
247*9e3b08aeSAndroid Build Coastguard Worker     return std::nullopt;
248*9e3b08aeSAndroid Build Coastguard Worker   }
249*9e3b08aeSAndroid Build Coastguard Worker 
AllSameSimpleEqualityCache250*9e3b08aeSAndroid Build Coastguard Worker   void AllSame(const std::vector<Pair>& comparisons) {
251*9e3b08aeSAndroid Build Coastguard Worker     for (const auto& comparison : comparisons) {
252*9e3b08aeSAndroid Build Coastguard Worker       ++known_equality_inserts;
253*9e3b08aeSAndroid Build Coastguard Worker       known_equalities.insert(comparison);
254*9e3b08aeSAndroid Build Coastguard Worker     }
255*9e3b08aeSAndroid Build Coastguard Worker   }
256*9e3b08aeSAndroid Build Coastguard Worker 
AllDifferentSimpleEqualityCache257*9e3b08aeSAndroid Build Coastguard Worker   void AllDifferent(const std::vector<Pair>&) {}
258*9e3b08aeSAndroid Build Coastguard Worker 
259*9e3b08aeSAndroid Build Coastguard Worker   std::unordered_set<Pair> known_equalities;
260*9e3b08aeSAndroid Build Coastguard Worker 
261*9e3b08aeSAndroid Build Coastguard Worker   Counter query_count;
262*9e3b08aeSAndroid Build Coastguard Worker   Counter query_equal_ids;
263*9e3b08aeSAndroid Build Coastguard Worker   Counter query_known_equality;
264*9e3b08aeSAndroid Build Coastguard Worker   Counter known_equality_inserts;
265*9e3b08aeSAndroid Build Coastguard Worker };
266*9e3b08aeSAndroid Build Coastguard Worker 
267*9e3b08aeSAndroid Build Coastguard Worker }  // namespace stg
268*9e3b08aeSAndroid Build Coastguard Worker 
269*9e3b08aeSAndroid Build Coastguard Worker #endif  // STG_EQUALITY_CACHE_H_
270