xref: /aosp_15_r20/external/angle/src/compiler/translator/msl/ToposortStructs.cpp (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include <algorithm>
8 #include <functional>
9 #include <unordered_map>
10 #include <unordered_set>
11 #include <vector>
12 
13 #include "compiler/translator/ImmutableStringBuilder.h"
14 #include "compiler/translator/msl/AstHelpers.h"
15 #include "compiler/translator/msl/ToposortStructs.h"
16 #include "compiler/translator/tree_util/IntermNode_util.h"
17 #include "compiler/translator/tree_util/IntermTraverse.h"
18 
19 using namespace sh;
20 
21 ////////////////////////////////////////////////////////////////////////////////
22 
23 namespace
24 {
25 
// Set of nodes that one node has edges to.
template <class T>
using Edges = std::unordered_set<T>;

// Adjacency representation: each node maps to the set of nodes it has edges to.
template <class T>
using Graph = std::unordered_map<T, Edges<T>>;
31 
32 struct EdgeComparator
33 {
operator ()__anona131e7250111::EdgeComparator34     bool operator()(const TStructure *s1, const TStructure *s2) { return s2->name() < s1->name(); }
35 };
36 
BuildGraphImpl(SymbolEnv & symbolEnv,Graph<const TStructure * > & g,const TStructure * s)37 void BuildGraphImpl(SymbolEnv &symbolEnv, Graph<const TStructure *> &g, const TStructure *s)
38 {
39     if (g.find(s) != g.end())
40     {
41         return;
42     }
43 
44     Edges<const TStructure *> &es = g[s];
45 
46     const TFieldList &fs = s->fields();
47     for (const TField *f : fs)
48     {
49         if (const TStructure *z = symbolEnv.remap(f->type()->getStruct()))
50         {
51             es.insert(z);
52             BuildGraphImpl(symbolEnv, g, z);
53             Edges<const TStructure *> &ez = g[z];
54             es.insert(ez.begin(), ez.end());
55         }
56     }
57 }
58 
BuildGraph(SymbolEnv & symbolEnv,const std::vector<const TStructure * > & structs)59 Graph<const TStructure *> BuildGraph(SymbolEnv &symbolEnv,
60                                      const std::vector<const TStructure *> &structs)
61 {
62     Graph<const TStructure *> g;
63     for (const TStructure *s : structs)
64     {
65         BuildGraphImpl(symbolEnv, g, s);
66     }
67     return g;
68 }
69 
SortEdges(const std::unordered_set<const TStructure * > & structs)70 std::vector<const TStructure *> SortEdges(const std::unordered_set<const TStructure *> &structs)
71 {
72     std::vector<const TStructure *> sorted;
73     sorted.reserve(structs.size());
74     sorted.insert(sorted.begin(), structs.begin(), structs.end());
75     std::sort(sorted.begin(), sorted.end(), EdgeComparator());
76     return sorted;
77 }
78 
79 // Algorthm: https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search
80 // Note that the algorithm is modified to visit nodes in sorted order. This
81 // ensures consistent results. Without this, the returned order (in so far as
82 // leaf nodes) is undefined, because iterating over an unordered_set of pointers
83 // depends upon the actual pointer values. Consistent results is important for
84 // code that keys off the string of shaders for caching.
85 template <typename T>
Toposort(const Graph<T> & g)86 std::vector<T> Toposort(const Graph<T> &g)
87 {
88     // nodes with temporary mark
89     std::unordered_set<T> temps;
90 
91     // nodes without permanent mark
92     std::unordered_set<T> invPerms;
93     for (const auto &entry : g)
94     {
95         invPerms.insert(entry.first);
96     }
97 
98     // L <- Empty list that will contain the sorted elements
99     std::vector<T> L;
100 
101     // function visit(node n)
102     std::function<void(T)> visit = [&](T n) -> void {
103         // if n has a permanent mark then
104         if (invPerms.find(n) == invPerms.end())
105         {
106             // return
107             return;
108         }
109         // if n has a temporary mark then
110         if (temps.find(n) != temps.end())
111         {
112             // stop   (not a DAG)
113             UNREACHABLE();
114         }
115 
116         // mark n with a temporary mark
117         temps.insert(n);
118 
119         // for each node m with an edge from n to m do
120         auto enIter = g.find(n);
121         ASSERT(enIter != g.end());
122 
123         std::vector<T> sorted = SortEdges(enIter->second);
124         for (T m : sorted)
125         {
126             // visit(m)
127             visit(m);
128         }
129 
130         // remove temporary mark from n
131         temps.erase(n);
132         // mark n with a permanent mark
133         invPerms.erase(n);
134         // add n to head of L
135         L.push_back(n);
136     };
137 
138     // while exists nodes without a permanent mark do
139     while (!invPerms.empty())
140     {
141         // select an unmarked node n
142         std::vector<T> sorted = SortEdges(invPerms);
143         T n                   = *sorted.begin();
144         // visit(n)
145         visit(n);
146     }
147 
148     return L;
149 }
150 
CreateStructEqualityFunction(TSymbolTable & symbolTable,const TStructure & aStructType,const std::unordered_map<const TStructure *,const TFunction * > & equalityFunctions)151 TIntermFunctionDefinition *CreateStructEqualityFunction(
152     TSymbolTable &symbolTable,
153     const TStructure &aStructType,
154     const std::unordered_map<const TStructure *, const TFunction *> &equalityFunctions)
155 {
156     auto &funcEquality =
157         *new TFunction(&symbolTable, ImmutableString("equal"), SymbolType::AngleInternal,
158                        new TType(TBasicType::EbtBool), true);
159 
160     auto &aStruct = CreateInstanceVariable(symbolTable, aStructType, Name("a"));
161     auto &bStruct = CreateInstanceVariable(symbolTable, aStructType, Name("b"));
162     funcEquality.addParameter(&aStruct);
163     funcEquality.addParameter(&bStruct);
164 
165     auto &bodyEquality = *new TIntermBlock();
166     std::vector<TIntermTyped *> andNodes;
167 
168     const TFieldList &aFields = aStructType.fields();
169     const size_t size         = aFields.size();
170 
171     auto testEquality = [&](TIntermTyped &a, TIntermTyped &b) -> TIntermTyped * {
172         ASSERT(a.getType() == b.getType());
173         const TType &type = a.getType();
174         if (const TStructure *structure = type.getStruct(); structure != nullptr)
175         {
176             auto func = equalityFunctions.find(structure);
177             if (func != equalityFunctions.end())
178             {
179                 return TIntermAggregate::CreateFunctionCall(*func->second,
180                                                             new TIntermSequence{&a, &b});
181             }
182             UNREACHABLE();
183         }
184         return new TIntermBinary(TOperator::EOpEqual, &a, &b);
185     };
186 
187     for (size_t idx = 0; idx < size; ++idx)
188     {
189         const TField &aField    = *aFields[idx];
190         const TType &aFieldType = *aField.type();
191         const Name aFieldName(aField);
192 
193         if (aFieldType.isArray())
194         {
195             ASSERT(!aFieldType.isArrayOfArrays());  // TODO
196             int dim = aFieldType.getOutermostArraySize();
197             for (int d = 0; d < dim; ++d)
198             {
199                 auto &aAccess = AccessIndex(AccessField(aStruct, aFieldName), d);
200                 auto &bAccess = AccessIndex(AccessField(bStruct, aFieldName), d);
201                 auto *eqNode  = testEquality(bAccess, aAccess);
202                 andNodes.push_back(eqNode);
203             }
204         }
205         else
206         {
207             auto &aAccess = AccessField(aStruct, aFieldName);
208             auto &bAccess = AccessField(bStruct, aFieldName);
209             auto *eqNode  = testEquality(bAccess, aAccess);
210             andNodes.push_back(eqNode);
211         }
212     }
213 
214     ASSERT(andNodes.size() > 0);  // Empty structs are not allowed in GLSL
215     TIntermTyped *outNode = andNodes.back();
216     andNodes.pop_back();
217     for (TIntermTyped *andNode : andNodes)
218     {
219         outNode = new TIntermBinary(TOperator::EOpLogicalAnd, andNode, outNode);
220     }
221     bodyEquality.appendStatement(new TIntermBranch(TOperator::EOpReturn, outNode));
222     auto *funcProtoEquality = new TIntermFunctionPrototype(&funcEquality);
223     return new TIntermFunctionDefinition(funcProtoEquality, &bodyEquality);
224 }
225 
226 struct DeclaredStructure
227 {
228     TIntermDeclaration *declNode;
229     const TStructure *structure;
230 };
231 
GetAsDeclaredStructure(SymbolEnv & symbolEnv,TIntermNode & node,DeclaredStructure & out)232 bool GetAsDeclaredStructure(SymbolEnv &symbolEnv, TIntermNode &node, DeclaredStructure &out)
233 {
234     if (TIntermDeclaration *declNode = node.getAsDeclarationNode())
235     {
236         ASSERT(declNode->getChildCount() == 1);
237         TIntermNode &childNode = *declNode->getChildNode(0);
238 
239         if (TIntermSymbol *symbolNode = childNode.getAsSymbolNode())
240         {
241             const TVariable &var = symbolNode->variable();
242             const TType &type    = var.getType();
243             if (const TStructure *structure = symbolEnv.remap(type.getStruct()))
244             {
245                 if (type.isStructSpecifier())
246                 {
247                     out.declNode  = declNode;
248                     out.structure = structure;
249                     return true;
250                 }
251             }
252         }
253     }
254     return false;
255 }
256 
257 class FindStructEqualityUse : public TIntermTraverser
258 {
259   public:
260     SymbolEnv &mSymbolEnv;
261     std::unordered_set<const TStructure *> mUsedStructs;
262 
FindStructEqualityUse(SymbolEnv & symbolEnv)263     FindStructEqualityUse(SymbolEnv &symbolEnv)
264         : TIntermTraverser(false, false, true), mSymbolEnv(symbolEnv)
265     {}
266 
visitBinary(Visit,TIntermBinary * binary)267     bool visitBinary(Visit, TIntermBinary *binary) override
268     {
269         const TOperator op = binary->getOp();
270 
271         switch (op)
272         {
273             case TOperator::EOpEqual:
274             case TOperator::EOpNotEqual:
275             {
276                 const TType &leftType  = binary->getLeft()->getType();
277                 const TType &rightType = binary->getRight()->getType();
278                 ASSERT(leftType.getStruct() == rightType.getStruct());
279                 if (const TStructure *structure = mSymbolEnv.remap(leftType.getStruct()))
280                 {
281                     useStruct(*structure);
282                 }
283             }
284             break;
285 
286             default:
287                 break;
288         }
289 
290         return true;
291     }
292 
293   private:
useStruct(const TStructure & structure)294     void useStruct(const TStructure &structure)
295     {
296         if (mUsedStructs.insert(&structure).second)
297         {
298             for (const TField *field : structure.fields())
299             {
300                 if (const TStructure *subStruct = mSymbolEnv.remap(field->type()->getStruct()))
301                 {
302                     useStruct(*subStruct);
303                 }
304             }
305         }
306     }
307 };
308 
309 }  // anonymous namespace
310 
311 ////////////////////////////////////////////////////////////////////////////////
312 
ToposortStructs(TCompiler & compiler,SymbolEnv & symbolEnv,TIntermBlock & root,ProgramPreludeConfig & ppc)313 bool sh::ToposortStructs(TCompiler &compiler,
314                          SymbolEnv &symbolEnv,
315                          TIntermBlock &root,
316                          ProgramPreludeConfig &ppc)
317 {
318     FindStructEqualityUse finder(symbolEnv);
319     root.traverse(&finder);
320     auto &usedStructs = finder.mUsedStructs;
321 
322     std::vector<DeclaredStructure> declaredStructs;
323     std::vector<TIntermNode *> nonStructStmtNodes;
324 
325     {
326         DeclaredStructure declaredStruct;
327         const size_t stmtCount = root.getChildCount();
328         for (size_t i = 0; i < stmtCount; ++i)
329         {
330             TIntermNode &stmtNode = *root.getChildNode(i);
331             if (GetAsDeclaredStructure(symbolEnv, stmtNode, declaredStruct))
332             {
333                 declaredStructs.push_back(declaredStruct);
334             }
335             else
336             {
337                 nonStructStmtNodes.push_back(&stmtNode);
338             }
339         }
340     }
341 
342     {
343         std::vector<const TStructure *> structs;
344         std::unordered_map<const TStructure *, DeclaredStructure> rawToDeclared;
345 
346         for (const DeclaredStructure &d : declaredStructs)
347         {
348             structs.push_back(d.structure);
349             ASSERT(rawToDeclared.find(d.structure) == rawToDeclared.end());
350             rawToDeclared[d.structure] = d;
351         }
352 
353         // Note: Graph may contain more than only explicitly declared structures.
354         Graph<const TStructure *> g                   = BuildGraph(symbolEnv, structs);
355         std::vector<const TStructure *> sortedStructs = Toposort(g);
356         ASSERT(declaredStructs.size() <= sortedStructs.size());
357 
358         declaredStructs.clear();
359         for (const TStructure *s : sortedStructs)
360         {
361             auto it = rawToDeclared.find(s);
362             if (it != rawToDeclared.end())
363             {
364                 auto &d = it->second;
365                 ASSERT(d.declNode);
366                 declaredStructs.push_back(d);
367             }
368         }
369     }
370 
371     {
372         TIntermSequence newStmtNodes;
373         std::unordered_map<const TStructure *, const TFunction *> equalityFunctions;
374         for (auto &[declNode, structure] : declaredStructs)
375         {
376             newStmtNodes.push_back(declNode);
377             if (usedStructs.find(structure) != usedStructs.end())
378             {
379                 TIntermFunctionDefinition *eq = CreateStructEqualityFunction(
380                     compiler.getSymbolTable(), *structure, equalityFunctions);
381                 newStmtNodes.push_back(eq);
382                 equalityFunctions[structure] = eq->getFunction();
383             }
384         }
385 
386         for (TIntermNode *stmtNode : nonStructStmtNodes)
387         {
388             ASSERT(stmtNode);
389             newStmtNodes.push_back(stmtNode);
390         }
391 
392         *root.getSequence() = newStmtNodes;
393     }
394 
395     return compiler.validateAST(&root);
396 }
397