1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6
7 #include <algorithm>
8 #include <functional>
9 #include <unordered_map>
10 #include <unordered_set>
11 #include <vector>
12
13 #include "compiler/translator/ImmutableStringBuilder.h"
14 #include "compiler/translator/msl/AstHelpers.h"
15 #include "compiler/translator/msl/ToposortStructs.h"
16 #include "compiler/translator/tree_util/IntermNode_util.h"
17 #include "compiler/translator/tree_util/IntermTraverse.h"
18
19 using namespace sh;
20
21 ////////////////////////////////////////////////////////////////////////////////
22
23 namespace
24 {
25
// The set of nodes that a single graph node has outgoing edges to.
template <class Node>
using Edges = std::unordered_set<Node>;

// Adjacency-set graph: every node key maps to the set of nodes it has edges to.
template <class Node>
using Graph = std::unordered_map<Node, Edges<Node>>;
31
32 struct EdgeComparator
33 {
operator ()__anona131e7250111::EdgeComparator34 bool operator()(const TStructure *s1, const TStructure *s2) { return s2->name() < s1->name(); }
35 };
36
BuildGraphImpl(SymbolEnv & symbolEnv,Graph<const TStructure * > & g,const TStructure * s)37 void BuildGraphImpl(SymbolEnv &symbolEnv, Graph<const TStructure *> &g, const TStructure *s)
38 {
39 if (g.find(s) != g.end())
40 {
41 return;
42 }
43
44 Edges<const TStructure *> &es = g[s];
45
46 const TFieldList &fs = s->fields();
47 for (const TField *f : fs)
48 {
49 if (const TStructure *z = symbolEnv.remap(f->type()->getStruct()))
50 {
51 es.insert(z);
52 BuildGraphImpl(symbolEnv, g, z);
53 Edges<const TStructure *> &ez = g[z];
54 es.insert(ez.begin(), ez.end());
55 }
56 }
57 }
58
BuildGraph(SymbolEnv & symbolEnv,const std::vector<const TStructure * > & structs)59 Graph<const TStructure *> BuildGraph(SymbolEnv &symbolEnv,
60 const std::vector<const TStructure *> &structs)
61 {
62 Graph<const TStructure *> g;
63 for (const TStructure *s : structs)
64 {
65 BuildGraphImpl(symbolEnv, g, s);
66 }
67 return g;
68 }
69
SortEdges(const std::unordered_set<const TStructure * > & structs)70 std::vector<const TStructure *> SortEdges(const std::unordered_set<const TStructure *> &structs)
71 {
72 std::vector<const TStructure *> sorted;
73 sorted.reserve(structs.size());
74 sorted.insert(sorted.begin(), structs.begin(), structs.end());
75 std::sort(sorted.begin(), sorted.end(), EdgeComparator());
76 return sorted;
77 }
78
79 // Algorthm: https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search
80 // Note that the algorithm is modified to visit nodes in sorted order. This
81 // ensures consistent results. Without this, the returned order (in so far as
82 // leaf nodes) is undefined, because iterating over an unordered_set of pointers
83 // depends upon the actual pointer values. Consistent results is important for
84 // code that keys off the string of shaders for caching.
85 template <typename T>
Toposort(const Graph<T> & g)86 std::vector<T> Toposort(const Graph<T> &g)
87 {
88 // nodes with temporary mark
89 std::unordered_set<T> temps;
90
91 // nodes without permanent mark
92 std::unordered_set<T> invPerms;
93 for (const auto &entry : g)
94 {
95 invPerms.insert(entry.first);
96 }
97
98 // L <- Empty list that will contain the sorted elements
99 std::vector<T> L;
100
101 // function visit(node n)
102 std::function<void(T)> visit = [&](T n) -> void {
103 // if n has a permanent mark then
104 if (invPerms.find(n) == invPerms.end())
105 {
106 // return
107 return;
108 }
109 // if n has a temporary mark then
110 if (temps.find(n) != temps.end())
111 {
112 // stop (not a DAG)
113 UNREACHABLE();
114 }
115
116 // mark n with a temporary mark
117 temps.insert(n);
118
119 // for each node m with an edge from n to m do
120 auto enIter = g.find(n);
121 ASSERT(enIter != g.end());
122
123 std::vector<T> sorted = SortEdges(enIter->second);
124 for (T m : sorted)
125 {
126 // visit(m)
127 visit(m);
128 }
129
130 // remove temporary mark from n
131 temps.erase(n);
132 // mark n with a permanent mark
133 invPerms.erase(n);
134 // add n to head of L
135 L.push_back(n);
136 };
137
138 // while exists nodes without a permanent mark do
139 while (!invPerms.empty())
140 {
141 // select an unmarked node n
142 std::vector<T> sorted = SortEdges(invPerms);
143 T n = *sorted.begin();
144 // visit(n)
145 visit(n);
146 }
147
148 return L;
149 }
150
CreateStructEqualityFunction(TSymbolTable & symbolTable,const TStructure & aStructType,const std::unordered_map<const TStructure *,const TFunction * > & equalityFunctions)151 TIntermFunctionDefinition *CreateStructEqualityFunction(
152 TSymbolTable &symbolTable,
153 const TStructure &aStructType,
154 const std::unordered_map<const TStructure *, const TFunction *> &equalityFunctions)
155 {
156 auto &funcEquality =
157 *new TFunction(&symbolTable, ImmutableString("equal"), SymbolType::AngleInternal,
158 new TType(TBasicType::EbtBool), true);
159
160 auto &aStruct = CreateInstanceVariable(symbolTable, aStructType, Name("a"));
161 auto &bStruct = CreateInstanceVariable(symbolTable, aStructType, Name("b"));
162 funcEquality.addParameter(&aStruct);
163 funcEquality.addParameter(&bStruct);
164
165 auto &bodyEquality = *new TIntermBlock();
166 std::vector<TIntermTyped *> andNodes;
167
168 const TFieldList &aFields = aStructType.fields();
169 const size_t size = aFields.size();
170
171 auto testEquality = [&](TIntermTyped &a, TIntermTyped &b) -> TIntermTyped * {
172 ASSERT(a.getType() == b.getType());
173 const TType &type = a.getType();
174 if (const TStructure *structure = type.getStruct(); structure != nullptr)
175 {
176 auto func = equalityFunctions.find(structure);
177 if (func != equalityFunctions.end())
178 {
179 return TIntermAggregate::CreateFunctionCall(*func->second,
180 new TIntermSequence{&a, &b});
181 }
182 UNREACHABLE();
183 }
184 return new TIntermBinary(TOperator::EOpEqual, &a, &b);
185 };
186
187 for (size_t idx = 0; idx < size; ++idx)
188 {
189 const TField &aField = *aFields[idx];
190 const TType &aFieldType = *aField.type();
191 const Name aFieldName(aField);
192
193 if (aFieldType.isArray())
194 {
195 ASSERT(!aFieldType.isArrayOfArrays()); // TODO
196 int dim = aFieldType.getOutermostArraySize();
197 for (int d = 0; d < dim; ++d)
198 {
199 auto &aAccess = AccessIndex(AccessField(aStruct, aFieldName), d);
200 auto &bAccess = AccessIndex(AccessField(bStruct, aFieldName), d);
201 auto *eqNode = testEquality(bAccess, aAccess);
202 andNodes.push_back(eqNode);
203 }
204 }
205 else
206 {
207 auto &aAccess = AccessField(aStruct, aFieldName);
208 auto &bAccess = AccessField(bStruct, aFieldName);
209 auto *eqNode = testEquality(bAccess, aAccess);
210 andNodes.push_back(eqNode);
211 }
212 }
213
214 ASSERT(andNodes.size() > 0); // Empty structs are not allowed in GLSL
215 TIntermTyped *outNode = andNodes.back();
216 andNodes.pop_back();
217 for (TIntermTyped *andNode : andNodes)
218 {
219 outNode = new TIntermBinary(TOperator::EOpLogicalAnd, andNode, outNode);
220 }
221 bodyEquality.appendStatement(new TIntermBranch(TOperator::EOpReturn, outNode));
222 auto *funcProtoEquality = new TIntermFunctionPrototype(&funcEquality);
223 return new TIntermFunctionDefinition(funcProtoEquality, &bodyEquality);
224 }
225
226 struct DeclaredStructure
227 {
228 TIntermDeclaration *declNode;
229 const TStructure *structure;
230 };
231
GetAsDeclaredStructure(SymbolEnv & symbolEnv,TIntermNode & node,DeclaredStructure & out)232 bool GetAsDeclaredStructure(SymbolEnv &symbolEnv, TIntermNode &node, DeclaredStructure &out)
233 {
234 if (TIntermDeclaration *declNode = node.getAsDeclarationNode())
235 {
236 ASSERT(declNode->getChildCount() == 1);
237 TIntermNode &childNode = *declNode->getChildNode(0);
238
239 if (TIntermSymbol *symbolNode = childNode.getAsSymbolNode())
240 {
241 const TVariable &var = symbolNode->variable();
242 const TType &type = var.getType();
243 if (const TStructure *structure = symbolEnv.remap(type.getStruct()))
244 {
245 if (type.isStructSpecifier())
246 {
247 out.declNode = declNode;
248 out.structure = structure;
249 return true;
250 }
251 }
252 }
253 }
254 return false;
255 }
256
257 class FindStructEqualityUse : public TIntermTraverser
258 {
259 public:
260 SymbolEnv &mSymbolEnv;
261 std::unordered_set<const TStructure *> mUsedStructs;
262
FindStructEqualityUse(SymbolEnv & symbolEnv)263 FindStructEqualityUse(SymbolEnv &symbolEnv)
264 : TIntermTraverser(false, false, true), mSymbolEnv(symbolEnv)
265 {}
266
visitBinary(Visit,TIntermBinary * binary)267 bool visitBinary(Visit, TIntermBinary *binary) override
268 {
269 const TOperator op = binary->getOp();
270
271 switch (op)
272 {
273 case TOperator::EOpEqual:
274 case TOperator::EOpNotEqual:
275 {
276 const TType &leftType = binary->getLeft()->getType();
277 const TType &rightType = binary->getRight()->getType();
278 ASSERT(leftType.getStruct() == rightType.getStruct());
279 if (const TStructure *structure = mSymbolEnv.remap(leftType.getStruct()))
280 {
281 useStruct(*structure);
282 }
283 }
284 break;
285
286 default:
287 break;
288 }
289
290 return true;
291 }
292
293 private:
useStruct(const TStructure & structure)294 void useStruct(const TStructure &structure)
295 {
296 if (mUsedStructs.insert(&structure).second)
297 {
298 for (const TField *field : structure.fields())
299 {
300 if (const TStructure *subStruct = mSymbolEnv.remap(field->type()->getStruct()))
301 {
302 useStruct(*subStruct);
303 }
304 }
305 }
306 }
307 };
308
309 } // anonymous namespace
310
311 ////////////////////////////////////////////////////////////////////////////////
312
ToposortStructs(TCompiler & compiler,SymbolEnv & symbolEnv,TIntermBlock & root,ProgramPreludeConfig & ppc)313 bool sh::ToposortStructs(TCompiler &compiler,
314 SymbolEnv &symbolEnv,
315 TIntermBlock &root,
316 ProgramPreludeConfig &ppc)
317 {
318 FindStructEqualityUse finder(symbolEnv);
319 root.traverse(&finder);
320 auto &usedStructs = finder.mUsedStructs;
321
322 std::vector<DeclaredStructure> declaredStructs;
323 std::vector<TIntermNode *> nonStructStmtNodes;
324
325 {
326 DeclaredStructure declaredStruct;
327 const size_t stmtCount = root.getChildCount();
328 for (size_t i = 0; i < stmtCount; ++i)
329 {
330 TIntermNode &stmtNode = *root.getChildNode(i);
331 if (GetAsDeclaredStructure(symbolEnv, stmtNode, declaredStruct))
332 {
333 declaredStructs.push_back(declaredStruct);
334 }
335 else
336 {
337 nonStructStmtNodes.push_back(&stmtNode);
338 }
339 }
340 }
341
342 {
343 std::vector<const TStructure *> structs;
344 std::unordered_map<const TStructure *, DeclaredStructure> rawToDeclared;
345
346 for (const DeclaredStructure &d : declaredStructs)
347 {
348 structs.push_back(d.structure);
349 ASSERT(rawToDeclared.find(d.structure) == rawToDeclared.end());
350 rawToDeclared[d.structure] = d;
351 }
352
353 // Note: Graph may contain more than only explicitly declared structures.
354 Graph<const TStructure *> g = BuildGraph(symbolEnv, structs);
355 std::vector<const TStructure *> sortedStructs = Toposort(g);
356 ASSERT(declaredStructs.size() <= sortedStructs.size());
357
358 declaredStructs.clear();
359 for (const TStructure *s : sortedStructs)
360 {
361 auto it = rawToDeclared.find(s);
362 if (it != rawToDeclared.end())
363 {
364 auto &d = it->second;
365 ASSERT(d.declNode);
366 declaredStructs.push_back(d);
367 }
368 }
369 }
370
371 {
372 TIntermSequence newStmtNodes;
373 std::unordered_map<const TStructure *, const TFunction *> equalityFunctions;
374 for (auto &[declNode, structure] : declaredStructs)
375 {
376 newStmtNodes.push_back(declNode);
377 if (usedStructs.find(structure) != usedStructs.end())
378 {
379 TIntermFunctionDefinition *eq = CreateStructEqualityFunction(
380 compiler.getSymbolTable(), *structure, equalityFunctions);
381 newStmtNodes.push_back(eq);
382 equalityFunctions[structure] = eq->getFunction();
383 }
384 }
385
386 for (TIntermNode *stmtNode : nonStructStmtNodes)
387 {
388 ASSERT(stmtNode);
389 newStmtNodes.push_back(stmtNode);
390 }
391
392 *root.getSequence() = newStmtNodes;
393 }
394
395 return compiler.validateAST(&root);
396 }
397