1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "common/system_utils.h"
#include "compiler/translator/IntermRebuild.h"
#include "compiler/translator/tree_ops/msl/SeparateCompoundExpressions.h"
#include "compiler/translator/tree_util/IntermNode_util.h"
#include "compiler/translator/util.h"
14
15 using namespace sh;
16
17 ////////////////////////////////////////////////////////////////////////////////
18
19 namespace
20 {
21
IsIndex(TOperator op)22 bool IsIndex(TOperator op)
23 {
24 switch (op)
25 {
26 case TOperator::EOpIndexDirect:
27 case TOperator::EOpIndexDirectInterfaceBlock:
28 case TOperator::EOpIndexDirectStruct:
29 case TOperator::EOpIndexIndirect:
30 return true;
31 default:
32 return false;
33 }
34 }
35
IsIndex(TIntermTyped & expr)36 bool IsIndex(TIntermTyped &expr)
37 {
38 if (auto *binary = expr.getAsBinaryNode())
39 {
40 return IsIndex(binary->getOp());
41 }
42 return expr.getAsSwizzleNode();
43 }
44
IsCompoundAssignment(TOperator op)45 bool IsCompoundAssignment(TOperator op)
46 {
47 switch (op)
48 {
49 case EOpAddAssign:
50 case EOpSubAssign:
51 case EOpMulAssign:
52 case EOpVectorTimesMatrixAssign:
53 case EOpVectorTimesScalarAssign:
54 case EOpMatrixTimesScalarAssign:
55 case EOpMatrixTimesMatrixAssign:
56 case EOpDivAssign:
57 case EOpIModAssign:
58 case EOpBitShiftLeftAssign:
59 case EOpBitShiftRightAssign:
60 case EOpBitwiseAndAssign:
61 case EOpBitwiseXorAssign:
62 case EOpBitwiseOrAssign:
63 return true;
64 default:
65 return false;
66 }
67 }
68
ViewBinaryChain(TOperator op,TIntermTyped & node,std::vector<TIntermTyped * > & out)69 bool ViewBinaryChain(TOperator op, TIntermTyped &node, std::vector<TIntermTyped *> &out)
70 {
71 TIntermBinary *binary = node.getAsBinaryNode();
72 if (!binary || binary->getOp() != op)
73 {
74 return false;
75 }
76
77 TIntermTyped *left = binary->getLeft();
78 TIntermTyped *right = binary->getRight();
79
80 if (!ViewBinaryChain(op, *left, out))
81 {
82 out.push_back(left);
83 }
84
85 if (!ViewBinaryChain(op, *right, out))
86 {
87 out.push_back(right);
88 }
89
90 return true;
91 }
92
ViewBinaryChain(TIntermBinary & node)93 std::vector<TIntermTyped *> ViewBinaryChain(TIntermBinary &node)
94 {
95 std::vector<TIntermTyped *> chain;
96 ViewBinaryChain(node.getOp(), node, chain);
97 ASSERT(chain.size() >= 2);
98 return chain;
99 }
100
101 class PrePass : public TIntermRebuild
102 {
103 public:
PrePass(TCompiler & compiler)104 PrePass(TCompiler &compiler) : TIntermRebuild(compiler, true, true) {}
105
106 private:
107 // Change chains of
108 // x OP y OP z
109 // to
110 // x OP (y OP z)
111 // regardless of original parenthesization.
reassociateRight(TIntermBinary & node)112 TIntermTyped &reassociateRight(TIntermBinary &node)
113 {
114 const TOperator op = node.getOp();
115 std::vector<TIntermTyped *> chain = ViewBinaryChain(node);
116
117 TIntermTyped *result = chain.back();
118 chain.pop_back();
119 ASSERT(result);
120
121 const auto begin = chain.rbegin();
122 const auto end = chain.rend();
123
124 for (auto iter = begin; iter != end; ++iter)
125 {
126 TIntermTyped *part = *iter;
127 ASSERT(part);
128 TIntermNode *temp = rebuild(*part).single();
129 ASSERT(temp);
130 part = temp->getAsTyped();
131 ASSERT(part);
132 result = new TIntermBinary(op, part, result);
133 }
134 return *result;
135 }
136
137 private:
visitBinaryPre(TIntermBinary & node)138 PreResult visitBinaryPre(TIntermBinary &node) override
139 {
140 const TOperator op = node.getOp();
141 if (op == TOperator::EOpLogicalAnd || op == TOperator::EOpLogicalOr)
142 {
143 return {reassociateRight(node), VisitBits::Neither};
144 }
145 return node;
146 }
147 };
148
// Rewrites compound expressions into sequences of simpler statements.
//
// Intermediate results are bound to temporaries (see pushBinding).
// `&&`/`||` are lowered to a bool temporary plus an if statement (see visitBinaryPost).
// `?:` is lowered to a temporary plus an if/else (see visitTernaryPost).
// Statements produced while visiting a scope are collected in mStmtsStack and written back
// into the owning TIntermBlock by visitBlockPost.
class Separator : public TIntermRebuild
{
    // Source of fresh names for the bool temporaries created when lowering `&&`/`||`.
    IdGen &mIdGen;
    // One list of pending output statements per open scope: a block, the right operand of
    // `&&`/`||`, or a branch of `?:`. The back entry is the scope currently being filled.
    std::vector<std::vector<TIntermNode *>> mStmtsStack;
    // Parallel to mStmtsStack. For each scope, maps a temporary created by pushBinding to
    // the declaration that initializes it. pullMappedExpr uses it to fold a temporary back
    // into its use.
    std::vector<std::unordered_map<const TVariable *, TIntermDeclaration *>> mBindingMapStack;
    // Maps an original expression node to the expression that should replace it in its
    // parent: a temporary's symbol, or a rewritten index/swizzle expression. Entries are
    // removed when consumed by pullMappedExpr.
    std::unordered_map<TIntermTyped *, TIntermTyped *> mExprMap;
    // Temporary declarations whose initializer was folded into a use by pullMappedExpr.
    // pushStmtsIntoBlock drops them instead of emitting them.
    std::unordered_set<TIntermDeclaration *> mMaskedDecls;

  public:
    // NOTE(review): |symbolEnv| is accepted but not used by this pass.
    Separator(TCompiler &compiler, SymbolEnv &symbolEnv, IdGen &idGen)
        : TIntermRebuild(compiler, true, true), mIdGen(idGen)
    {}

    // By the end of the pass, every scope must have been popped and every mapped
    // expression consumed.
    ~Separator() override
    {
        ASSERT(mStmtsStack.empty());
        ASSERT(mExprMap.empty());
        ASSERT(mBindingMapStack.empty());
    }

  private:
    // Statement list of the innermost open scope.
    std::vector<TIntermNode *> &getCurrStmts()
    {
        ASSERT(!mStmtsStack.empty());
        return mStmtsStack.back();
    }

    // Temporary-to-declaration map of the innermost open scope.
    std::unordered_map<const TVariable *, TIntermDeclaration *> &getCurrBindingMap()
    {
        ASSERT(!mBindingMapStack.empty());
        return mBindingMapStack.back();
    }

    // Appends |node| as an output statement of the innermost open scope.
    void pushStmt(TIntermNode &node) { getCurrStmts().push_back(&node); }

    // True for symbols and constants, which need no further separation.
    bool isTerminalExpr(TIntermNode &node)
    {
        NodeType nodeType = getNodeType(node);
        switch (nodeType)
        {
            case NodeType::Symbol:
            case NodeType::ConstantUnion:
                return true;
            default:
                return false;
        }
    }

    // Returns the replacement recorded in mExprMap for |node| and consumes that entry.
    // If there is no entry, |node| itself is returned. |node| may be null, e.g. a branch
    // without an expression; null is simply returned as-is.
    //
    // With |allowBacktrack|, when the replacement is a temporary declared in the current
    // scope, the temporary is replaced by its initializer expression. Its declaration is
    // then masked so it is never emitted. This repeats while the result is such a symbol.
    TIntermTyped *pullMappedExpr(TIntermTyped *node, bool allowBacktrack)
    {
        TIntermTyped *expr;

        {
            auto iter = mExprMap.find(node);
            if (iter == mExprMap.end())
            {
                return node;
            }
            ASSERT(node);
            expr = iter->second;
            ASSERT(expr);
            mExprMap.erase(iter);
        }

        if (allowBacktrack)
        {
            auto &bindingMap = getCurrBindingMap();
            while (TIntermSymbol *symbol = expr->getAsSymbolNode())
            {
                const TVariable &var = symbol->variable();
                auto iter = bindingMap.find(&var);
                if (iter == bindingMap.end())
                {
                    // Not a temporary of this scope; keep the symbol.
                    return expr;
                }
                ASSERT(var.symbolType() == SymbolType::AngleInternal);
                TIntermDeclaration *decl = iter->second;
                ASSERT(decl);
                expr = ViewDeclaration(*decl).initExpr;
                ASSERT(expr);
                bindingMap.erase(iter);
                mMaskedDecls.insert(decl);
            }
        }

        return expr;
    }

    // True when the expression being visited is used as a statement rather than as a value:
    // its parent is a block, or |expr| has void type.
    bool isStandaloneExpr(TIntermTyped &expr)
    {
        if (getParentNode()->getAsBlock())
        {
            return true;
        }
        // https://bugs.webkit.org/show_bug.cgi?id=227723: Fix for sequence operator.
        if ((expr.getType().getBasicType() == TBasicType::EbtVoid))
        {
            return true;
        }
        return false;
    }

    // Records how the (already separated) |newExpr| replaces |oldExpr| in its parent:
    //  - standalone expressions are emitted directly as statements;
    //  - index/swizzle expressions are mapped as-is, without a temporary. Presumably this
    //    keeps them usable as assignment targets; TODO confirm;
    //  - anything else is stored in a new temporary, and |oldExpr| maps to that temporary.
    void pushBinding(TIntermTyped &oldExpr, TIntermTyped &newExpr)
    {
        if (isStandaloneExpr(newExpr))
        {
            pushStmt(newExpr);
            return;
        }
        if (IsIndex(newExpr))
        {
            mExprMap[&oldExpr] = &newExpr;
            return;
        }
        auto &bindingMap = getCurrBindingMap();
        auto *var = CreateTempVariable(&mSymbolTable, &newExpr.getType(), EvqTemporary);
        auto *decl = new TIntermDeclaration(var, &newExpr);
        pushStmt(*decl);
        mExprMap[&oldExpr] = new TIntermSymbol(var);
        bindingMap[var] = decl;
    }

    // Opens a new scope, with an empty statement list and binding map.
    void pushStacks()
    {
        mStmtsStack.emplace_back();
        mBindingMapStack.emplace_back();
    }

    // Closes the innermost scope. Callers must have moved its statements out first.
    void popStacks()
    {
        ASSERT(!mBindingMapStack.empty());
        ASSERT(!mStmtsStack.empty());
        ASSERT(mStmtsStack.back().empty());
        mBindingMapStack.pop_back();
        mStmtsStack.pop_back();
    }

    // Appends |stmts| to |block|, skipping declarations masked by pullMappedExpr and
    // consuming their mask entries.
    void pushStmtsIntoBlock(TIntermBlock &block, std::vector<TIntermNode *> &stmts)
    {
        TIntermSequence &seq = *block.getSequence();
        for (TIntermNode *stmt : stmts)
        {
            if (TIntermDeclaration *decl = stmt->getAsDeclarationNode())
            {
                auto iter = mMaskedDecls.find(decl);
                if (iter != mMaskedDecls.end())
                {
                    mMaskedDecls.erase(iter);
                    continue;
                }
            }
            seq.push_back(stmt);
        }
    }

    // Closes the innermost scope and returns a new block containing its statements,
    // followed by `var = newExpr`.
    TIntermBlock &buildBlockWithTailAssign(const TVariable &var, TIntermTyped &newExpr)
    {
        std::vector<TIntermNode *> stmts = std::move(getCurrStmts());
        popStacks();

        auto &block = *new TIntermBlock();
        auto &seq = *block.getSequence();
        seq.reserve(1 + stmts.size());
        pushStmtsIntoBlock(block, stmts);
        seq.push_back(new TIntermBinary(TOperator::EOpAssign, new TIntermSymbol(&var), &newExpr));

        return block;
    }

  private:
    // Every block opens its own scope.
    PreResult visitBlockPre(TIntermBlock &node) override
    {
        pushStacks();
        return node;
    }

    // Replaces the block's contents with the statements collected for its scope. A block
    // nested directly inside another block is itself emitted into the enclosing scope.
    PostResult visitBlockPost(TIntermBlock &node) override
    {
        std::vector<TIntermNode *> stmts = std::move(getCurrStmts());
        popStacks();

        TIntermSequence &seq = *node.getSequence();
        seq.clear();
        seq.reserve(stmts.size());
        pushStmtsIntoBlock(node, stmts);

        TIntermNode *parent = getParentNode();
        if (parent && parent->getAsBlock())
        {
            pushStmt(node);
        }

        return node;
    }

    // Declarations with no initializer, or with a symbol/constant initializer, are emitted
    // unchanged and their children are not visited.
    PreResult visitDeclarationPre(TIntermDeclaration &node) override
    {
        Declaration decl = ViewDeclaration(node);
        if (!decl.initExpr || isTerminalExpr(*decl.initExpr))
        {
            pushStmt(node);
            return {node, VisitBits::Neither};
        }
        return node;
    }

    // Emits the declaration with its initializer replaced by the separated expression. A
    // temporary holding that expression is folded back into the initializer (backtracking).
    PostResult visitDeclarationPost(TIntermDeclaration &node) override
    {
        Declaration decl = ViewDeclaration(node);
        ASSERT(decl.symbol.variable().symbolType() != SymbolType::Empty);
        ASSERT(!decl.symbol.variable().getType().isStructSpecifier());

        TIntermTyped *newInitExpr = pullMappedExpr(decl.initExpr, true);
        if (decl.initExpr == newInitExpr)
        {
            pushStmt(node);
        }
        else
        {
            auto &newNode = *new TIntermDeclaration();
            newNode.appendDeclarator(
                new TIntermBinary(TOperator::EOpInitialize, &decl.symbol, newInitExpr));
            pushStmt(newNode);
        }
        return node;
    }

    // Rebuilds the unary node over its separated operand, then binds the result.
    PostResult visitUnaryPost(TIntermUnary &node) override
    {
        TIntermTyped *expr = node.getOperand();
        TIntermTyped *newExpr = pullMappedExpr(expr, false);
        if (expr == newExpr)
        {
            pushBinding(node, node);
        }
        else
        {
            pushBinding(node, *new TIntermUnary(node.getOp(), newExpr, node.getFunction()));
        }
        return node;
    }

    // For `&&`/`||`, the left operand is separated into the current scope. The right
    // operand is separated into a fresh scope, so its statements can later be placed in a
    // conditionally executed block. Children are then not visited again; only the post
    // visit runs.
    PreResult visitBinaryPre(TIntermBinary &node) override
    {
        const TOperator op = node.getOp();
        if (op == TOperator::EOpLogicalAnd || op == TOperator::EOpLogicalOr)
        {
            TIntermTyped *left = node.getLeft();
            TIntermTyped *right = node.getRight();

            PostResult leftResult = rebuild(*left);
            ASSERT(leftResult.single());

            pushStacks();
            PostResult rightResult = rebuild(*right);
            ASSERT(rightResult.single());

            return {node, VisitBits::Post};
        }

        return node;
    }

    PostResult visitBinaryPost(TIntermBinary &node) override
    {
        const TOperator op = node.getOp();
        if (op == TOperator::EOpInitialize && getParentNode()->getAsDeclarationNode())
        {
            // Special case is handled by visitDeclarationPost
            return node;
        }

        TIntermTyped *left = node.getLeft();
        TIntermTyped *right = node.getRight();

        if (op == TOperator::EOpLogicalAnd || op == TOperator::EOpLogicalOr)
        {
            // Lowered to:
            //   bool v = left;
            //   if (v)  { <right stmts>; v = right; }   // for &&
            //   if (!v) { <right stmts>; v = right; }   // for ||
            const Name name = mIdGen.createNewName();
            auto *var = new TVariable(&mSymbolTable, name.rawName(), new TType(TBasicType::EbtBool),
                                      name.symbolType());

            // The right operand's scope is innermost, so it is closed first.
            TIntermTyped *newRight = pullMappedExpr(right, true);
            TIntermBlock *rightBlock = &buildBlockWithTailAssign(*var, *newRight);
            TIntermTyped *newLeft = pullMappedExpr(left, true);

            TIntermTyped *cond = new TIntermSymbol(var);
            if (op == TOperator::EOpLogicalOr)
            {
                cond = new TIntermUnary(TOperator::EOpLogicalNot, cond, nullptr);
            }

            pushStmt(*new TIntermDeclaration(var, newLeft));
            pushStmt(*new TIntermIfElse(cond, rightBlock, nullptr));
            if (!isStandaloneExpr(node))
            {
                mExprMap[&node] = new TIntermSymbol(var);
            }

            return node;
        }

        // For a plain assignment, a temporary holding the right-hand side is folded back
        // into the assignment. Compound assignments and other operators keep the temporary.
        const bool isAssign = IsAssignment(op);
        const bool isCompoundAssign = IsCompoundAssignment(op);
        // Pulled even for EOpComma so that the left operand's map entry is consumed.
        TIntermTyped *newLeft = pullMappedExpr(left, false);
        TIntermTyped *newRight = pullMappedExpr(right, isAssign && !isCompoundAssign);
        if (op == TOperator::EOpComma)
        {
            // The comma expression's value is its right operand; the left value is discarded.
            pushBinding(node, *newRight);
            return node;
        }
        else
        {
            TIntermBinary *newNode;
            if (left == newLeft && right == newRight)
            {
                newNode = &node;
            }
            else
            {
                newNode = new TIntermBinary(op, newLeft, newRight);
            }
            pushBinding(node, *newNode);
            return node;
        }
    }

    // The condition is separated into the current scope. The true and false expressions
    // each get their own fresh scope; the false scope is innermost.
    PreResult visitTernaryPre(TIntermTernary &node) override
    {
        PostResult condResult = rebuild(*node.getCondition());
        ASSERT(condResult.single());

        pushStacks();
        PostResult thenResult = rebuild(*node.getTrueExpression());
        ASSERT(thenResult.single());

        pushStacks();
        PostResult elseResult = rebuild(*node.getFalseExpression());
        ASSERT(elseResult.single());

        return {node, VisitBits::Post};
    }

    // Lowered to:
    //   T v;
    //   if (cond) { <then stmts>; v = then; } else { <else stmts>; v = else; }
    // The scopes are closed in reverse order of visitTernaryPre: else, then, then the current
    // scope for cond.
    PostResult visitTernaryPost(TIntermTernary &node) override
    {
        TIntermTyped *cond = node.getCondition();
        TIntermTyped *then = node.getTrueExpression();
        TIntermTyped *else_ = node.getFalseExpression();

        auto *var = CreateTempVariable(&mSymbolTable, &node.getType(), EvqTemporary);
        // NOTE(review): unlike the true branch, backtracking is disabled for the false
        // branch. The result is still correct (the temporary is just kept); confirm whether
        // the asymmetry is intentional.
        TIntermTyped *newElse = pullMappedExpr(else_, false);
        TIntermBlock *elseBlock = &buildBlockWithTailAssign(*var, *newElse);
        TIntermTyped *newThen = pullMappedExpr(then, true);
        TIntermBlock *thenBlock = &buildBlockWithTailAssign(*var, *newThen);
        TIntermTyped *newCond = pullMappedExpr(cond, true);

        pushStmt(*new TIntermDeclaration{var});
        pushStmt(*new TIntermIfElse(newCond, thenBlock, elseBlock));
        if (!isStandaloneExpr(node))
        {
            mExprMap[&node] = new TIntermSymbol(var);
        }

        return node;
    }

    // Rebuilds the swizzle over its separated operand, then binds the result (swizzles are
    // mapped without a temporary; see pushBinding).
    PostResult visitSwizzlePost(TIntermSwizzle &node) override
    {
        TIntermTyped *expr = node.getOperand();
        TIntermTyped *newExpr = pullMappedExpr(expr, false);
        if (expr == newExpr)
        {
            pushBinding(node, node);
        }
        else
        {
            pushBinding(node, *new TIntermSwizzle(newExpr, node.getSwizzleOffsets()));
        }
        return node;
    }

    // Replaces each argument in place with its separated expression, then binds the call.
    PostResult visitAggregatePost(TIntermAggregate &node) override
    {
        TIntermSequence &args = *node.getSequence();
        for (TIntermNode *&arg : args)
        {
            TIntermTyped *targ = arg->getAsTyped();
            ASSERT(targ);
            arg = pullMappedExpr(targ, false);
        }
        pushBinding(node, node);
        return node;
    }

    PostResult visitPreprocessorDirectivePost(TIntermPreprocessorDirective &node) override
    {
        pushStmt(node);
        return node;
    }

    // Only prototypes outside a function definition are emitted as statements. A
    // definition's prototype stays with its TIntermFunctionDefinition.
    PostResult visitFunctionPrototypePost(TIntermFunctionPrototype &node) override
    {
        if (!getParentFunction())
        {
            pushStmt(node);
        }
        return node;
    }

    // Case labels must already be symbols or constants. They are emitted as-is, without
    // visiting children.
    PreResult visitCasePre(TIntermCase &node) override
    {
        if (TIntermTyped *cond = node.getCondition())
        {
            ASSERT(isTerminalExpr(*cond));
        }
        pushStmt(node);
        return {node, VisitBits::Neither};
    }

    // Emits the switch with its selector replaced by the separated expression. The
    // statement list block was already rewritten in place by visitBlockPost.
    PostResult visitSwitchPost(TIntermSwitch &node) override
    {
        TIntermTyped *init = node.getInit();
        TIntermTyped *newInit = pullMappedExpr(init, false);
        if (init == newInit)
        {
            pushStmt(node);
        }
        else
        {
            pushStmt(*new TIntermSwitch(newInit, node.getStatementList()));
        }

        return node;
    }

    PostResult visitFunctionDefinitionPost(TIntermFunctionDefinition &node) override
    {
        pushStmt(node);
        return node;
    }

    // Emits the if/else with its condition replaced by the separated expression. Statements
    // computing the condition were already emitted before it in the current scope.
    PostResult visitIfElsePost(TIntermIfElse &node) override
    {
        TIntermTyped *cond = node.getCondition();
        TIntermTyped *newCond = pullMappedExpr(cond, false);
        if (cond == newCond)
        {
            pushStmt(node);
        }
        else
        {
            pushStmt(*new TIntermIfElse(newCond, node.getTrueBlock(), node.getFalseBlock()));
        }
        return node;
    }

    // Emits return/break/continue/discard, with the returned expression (if any) replaced by
    // its separated form.
    PostResult visitBranchPost(TIntermBranch &node) override
    {
        TIntermTyped *expr = node.getExpression();
        TIntermTyped *newExpr = pullMappedExpr(expr, false);
        if (expr == newExpr)
        {
            pushStmt(node);
        }
        else
        {
            pushStmt(*new TIntermBranch(node.getFlowOp(), newExpr));
        }
        return node;
    }

    // Only the loop body is rebuilt (in place). Other loop children are not visited by this
    // pass, so any compound expressions in the loop header are left as they are.
    PreResult visitLoopPre(TIntermLoop &node) override
    {
        if (!rebuildInPlace(*node.getBody()))
        {
            UNREACHABLE();
        }
        pushStmt(node);
        return {node, VisitBits::Neither};
    }

    // Constants that are not scalars, vectors or matrices (e.g. array or struct constants)
    // are bound like any other expression. Other constants stay inline.
    PostResult visitConstantUnionPost(TIntermConstantUnion &node) override
    {
        const TType &type = node.getType();
        if (!type.isScalar() && !type.isVector() && !type.isMatrix())
        {
            pushBinding(node, node);
        }
        return node;
    }

    PostResult visitGlobalQualifierDeclarationPost(TIntermGlobalQualifierDeclaration &node) override
    {
        // With the removal of RewriteGlobalQualifierDecls, we may encounter globals while
        // separating compound expressions.
        pushStmt(node);
        return node;
    }
};
647
648 } // anonymous namespace
649
650 ////////////////////////////////////////////////////////////////////////////////
651
SeparateCompoundExpressions(TCompiler & compiler,SymbolEnv & symbolEnv,IdGen & idGen,TIntermBlock & root)652 bool sh::SeparateCompoundExpressions(TCompiler &compiler,
653 SymbolEnv &symbolEnv,
654 IdGen &idGen,
655 TIntermBlock &root)
656 {
657 if (angle::GetBoolEnvironmentVar("GMT_DISABLE_SEPARATE_COMPOUND_EXPRESSIONS"))
658 {
659 return true;
660 }
661
662 if (!PrePass(compiler).rebuildRoot(root))
663 {
664 return false;
665 }
666
667 if (!Separator(compiler, symbolEnv, idGen).rebuildRoot(root))
668 {
669 return false;
670 }
671
672 return true;
673 }
674