// xref: /aosp_15_r20/external/angle/src/compiler/translator/tree_ops/msl/SeparateCompoundExpressions.cpp (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
//
// Copyright 2020 The ANGLE Project Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//
6 
7 #include <unordered_map>
8 
9 #include "common/system_utils.h"
10 #include "compiler/translator/IntermRebuild.h"
11 #include "compiler/translator/tree_ops/msl/SeparateCompoundExpressions.h"
12 #include "compiler/translator/tree_util/IntermNode_util.h"
13 #include "compiler/translator/util.h"
14 
15 using namespace sh;
16 
////////////////////////////////////////////////////////////////////////////////
18 
19 namespace
20 {
21 
IsIndex(TOperator op)22 bool IsIndex(TOperator op)
23 {
24     switch (op)
25     {
26         case TOperator::EOpIndexDirect:
27         case TOperator::EOpIndexDirectInterfaceBlock:
28         case TOperator::EOpIndexDirectStruct:
29         case TOperator::EOpIndexIndirect:
30             return true;
31         default:
32             return false;
33     }
34 }
35 
IsIndex(TIntermTyped & expr)36 bool IsIndex(TIntermTyped &expr)
37 {
38     if (auto *binary = expr.getAsBinaryNode())
39     {
40         return IsIndex(binary->getOp());
41     }
42     return expr.getAsSwizzleNode();
43 }
44 
IsCompoundAssignment(TOperator op)45 bool IsCompoundAssignment(TOperator op)
46 {
47     switch (op)
48     {
49         case EOpAddAssign:
50         case EOpSubAssign:
51         case EOpMulAssign:
52         case EOpVectorTimesMatrixAssign:
53         case EOpVectorTimesScalarAssign:
54         case EOpMatrixTimesScalarAssign:
55         case EOpMatrixTimesMatrixAssign:
56         case EOpDivAssign:
57         case EOpIModAssign:
58         case EOpBitShiftLeftAssign:
59         case EOpBitShiftRightAssign:
60         case EOpBitwiseAndAssign:
61         case EOpBitwiseXorAssign:
62         case EOpBitwiseOrAssign:
63             return true;
64         default:
65             return false;
66     }
67 }
68 
ViewBinaryChain(TOperator op,TIntermTyped & node,std::vector<TIntermTyped * > & out)69 bool ViewBinaryChain(TOperator op, TIntermTyped &node, std::vector<TIntermTyped *> &out)
70 {
71     TIntermBinary *binary = node.getAsBinaryNode();
72     if (!binary || binary->getOp() != op)
73     {
74         return false;
75     }
76 
77     TIntermTyped *left  = binary->getLeft();
78     TIntermTyped *right = binary->getRight();
79 
80     if (!ViewBinaryChain(op, *left, out))
81     {
82         out.push_back(left);
83     }
84 
85     if (!ViewBinaryChain(op, *right, out))
86     {
87         out.push_back(right);
88     }
89 
90     return true;
91 }
92 
ViewBinaryChain(TIntermBinary & node)93 std::vector<TIntermTyped *> ViewBinaryChain(TIntermBinary &node)
94 {
95     std::vector<TIntermTyped *> chain;
96     ViewBinaryChain(node.getOp(), node, chain);
97     ASSERT(chain.size() >= 2);
98     return chain;
99 }
100 
101 class PrePass : public TIntermRebuild
102 {
103   public:
PrePass(TCompiler & compiler)104     PrePass(TCompiler &compiler) : TIntermRebuild(compiler, true, true) {}
105 
106   private:
107     // Change chains of
108     //      x OP y OP z
109     // to
110     //      x OP (y OP z)
111     // regardless of original parenthesization.
reassociateRight(TIntermBinary & node)112     TIntermTyped &reassociateRight(TIntermBinary &node)
113     {
114         const TOperator op                = node.getOp();
115         std::vector<TIntermTyped *> chain = ViewBinaryChain(node);
116 
117         TIntermTyped *result = chain.back();
118         chain.pop_back();
119         ASSERT(result);
120 
121         const auto begin = chain.rbegin();
122         const auto end   = chain.rend();
123 
124         for (auto iter = begin; iter != end; ++iter)
125         {
126             TIntermTyped *part = *iter;
127             ASSERT(part);
128             TIntermNode *temp = rebuild(*part).single();
129             ASSERT(temp);
130             part = temp->getAsTyped();
131             ASSERT(part);
132             result = new TIntermBinary(op, part, result);
133         }
134         return *result;
135     }
136 
137   private:
visitBinaryPre(TIntermBinary & node)138     PreResult visitBinaryPre(TIntermBinary &node) override
139     {
140         const TOperator op = node.getOp();
141         if (op == TOperator::EOpLogicalAnd || op == TOperator::EOpLogicalOr)
142         {
143             return {reassociateRight(node), VisitBits::Neither};
144         }
145         return node;
146     }
147 };
148 
149 class Separator : public TIntermRebuild
150 {
151     IdGen &mIdGen;
152     std::vector<std::vector<TIntermNode *>> mStmtsStack;
153     std::vector<std::unordered_map<const TVariable *, TIntermDeclaration *>> mBindingMapStack;
154     std::unordered_map<TIntermTyped *, TIntermTyped *> mExprMap;
155     std::unordered_set<TIntermDeclaration *> mMaskedDecls;
156 
157   public:
Separator(TCompiler & compiler,SymbolEnv & symbolEnv,IdGen & idGen)158     Separator(TCompiler &compiler, SymbolEnv &symbolEnv, IdGen &idGen)
159         : TIntermRebuild(compiler, true, true), mIdGen(idGen)
160     {}
161 
~Separator()162     ~Separator() override
163     {
164         ASSERT(mStmtsStack.empty());
165         ASSERT(mExprMap.empty());
166         ASSERT(mBindingMapStack.empty());
167     }
168 
169   private:
getCurrStmts()170     std::vector<TIntermNode *> &getCurrStmts()
171     {
172         ASSERT(!mStmtsStack.empty());
173         return mStmtsStack.back();
174     }
175 
getCurrBindingMap()176     std::unordered_map<const TVariable *, TIntermDeclaration *> &getCurrBindingMap()
177     {
178         ASSERT(!mBindingMapStack.empty());
179         return mBindingMapStack.back();
180     }
181 
pushStmt(TIntermNode & node)182     void pushStmt(TIntermNode &node) { getCurrStmts().push_back(&node); }
183 
isTerminalExpr(TIntermNode & node)184     bool isTerminalExpr(TIntermNode &node)
185     {
186         NodeType nodeType = getNodeType(node);
187         switch (nodeType)
188         {
189             case NodeType::Symbol:
190             case NodeType::ConstantUnion:
191                 return true;
192             default:
193                 return false;
194         }
195     }
196 
pullMappedExpr(TIntermTyped * node,bool allowBacktrack)197     TIntermTyped *pullMappedExpr(TIntermTyped *node, bool allowBacktrack)
198     {
199         TIntermTyped *expr;
200 
201         {
202             auto iter = mExprMap.find(node);
203             if (iter == mExprMap.end())
204             {
205                 return node;
206             }
207             ASSERT(node);
208             expr = iter->second;
209             ASSERT(expr);
210             mExprMap.erase(iter);
211         }
212 
213         if (allowBacktrack)
214         {
215             auto &bindingMap = getCurrBindingMap();
216             while (TIntermSymbol *symbol = expr->getAsSymbolNode())
217             {
218                 const TVariable &var = symbol->variable();
219                 auto iter            = bindingMap.find(&var);
220                 if (iter == bindingMap.end())
221                 {
222                     return expr;
223                 }
224                 ASSERT(var.symbolType() == SymbolType::AngleInternal);
225                 TIntermDeclaration *decl = iter->second;
226                 ASSERT(decl);
227                 expr = ViewDeclaration(*decl).initExpr;
228                 ASSERT(expr);
229                 bindingMap.erase(iter);
230                 mMaskedDecls.insert(decl);
231             }
232         }
233 
234         return expr;
235     }
236 
isStandaloneExpr(TIntermTyped & expr)237     bool isStandaloneExpr(TIntermTyped &expr)
238     {
239         if (getParentNode()->getAsBlock())
240         {
241             return true;
242         }
243         // https://bugs.webkit.org/show_bug.cgi?id=227723: Fix for sequence operator.
244         if ((expr.getType().getBasicType() == TBasicType::EbtVoid))
245         {
246             return true;
247         }
248         return false;
249     }
250 
pushBinding(TIntermTyped & oldExpr,TIntermTyped & newExpr)251     void pushBinding(TIntermTyped &oldExpr, TIntermTyped &newExpr)
252     {
253         if (isStandaloneExpr(newExpr))
254         {
255             pushStmt(newExpr);
256             return;
257         }
258         if (IsIndex(newExpr))
259         {
260             mExprMap[&oldExpr] = &newExpr;
261             return;
262         }
263         auto &bindingMap = getCurrBindingMap();
264         auto *var        = CreateTempVariable(&mSymbolTable, &newExpr.getType(), EvqTemporary);
265         auto *decl = new TIntermDeclaration(var, &newExpr);
266         pushStmt(*decl);
267         mExprMap[&oldExpr] = new TIntermSymbol(var);
268         bindingMap[var]    = decl;
269     }
270 
pushStacks()271     void pushStacks()
272     {
273         mStmtsStack.emplace_back();
274         mBindingMapStack.emplace_back();
275     }
276 
popStacks()277     void popStacks()
278     {
279         ASSERT(!mBindingMapStack.empty());
280         ASSERT(!mStmtsStack.empty());
281         ASSERT(mStmtsStack.back().empty());
282         mBindingMapStack.pop_back();
283         mStmtsStack.pop_back();
284     }
285 
pushStmtsIntoBlock(TIntermBlock & block,std::vector<TIntermNode * > & stmts)286     void pushStmtsIntoBlock(TIntermBlock &block, std::vector<TIntermNode *> &stmts)
287     {
288         TIntermSequence &seq = *block.getSequence();
289         for (TIntermNode *stmt : stmts)
290         {
291             if (TIntermDeclaration *decl = stmt->getAsDeclarationNode())
292             {
293                 auto iter = mMaskedDecls.find(decl);
294                 if (iter != mMaskedDecls.end())
295                 {
296                     mMaskedDecls.erase(iter);
297                     continue;
298                 }
299             }
300             seq.push_back(stmt);
301         }
302     }
303 
buildBlockWithTailAssign(const TVariable & var,TIntermTyped & newExpr)304     TIntermBlock &buildBlockWithTailAssign(const TVariable &var, TIntermTyped &newExpr)
305     {
306         std::vector<TIntermNode *> stmts = std::move(getCurrStmts());
307         popStacks();
308 
309         auto &block = *new TIntermBlock();
310         auto &seq   = *block.getSequence();
311         seq.reserve(1 + stmts.size());
312         pushStmtsIntoBlock(block, stmts);
313         seq.push_back(new TIntermBinary(TOperator::EOpAssign, new TIntermSymbol(&var), &newExpr));
314 
315         return block;
316     }
317 
318   private:
visitBlockPre(TIntermBlock & node)319     PreResult visitBlockPre(TIntermBlock &node) override
320     {
321         pushStacks();
322         return node;
323     }
324 
visitBlockPost(TIntermBlock & node)325     PostResult visitBlockPost(TIntermBlock &node) override
326     {
327         std::vector<TIntermNode *> stmts = std::move(getCurrStmts());
328         popStacks();
329 
330         TIntermSequence &seq = *node.getSequence();
331         seq.clear();
332         seq.reserve(stmts.size());
333         pushStmtsIntoBlock(node, stmts);
334 
335         TIntermNode *parent = getParentNode();
336         if (parent && parent->getAsBlock())
337         {
338             pushStmt(node);
339         }
340 
341         return node;
342     }
343 
visitDeclarationPre(TIntermDeclaration & node)344     PreResult visitDeclarationPre(TIntermDeclaration &node) override
345     {
346         Declaration decl = ViewDeclaration(node);
347         if (!decl.initExpr || isTerminalExpr(*decl.initExpr))
348         {
349             pushStmt(node);
350             return {node, VisitBits::Neither};
351         }
352         return node;
353     }
354 
visitDeclarationPost(TIntermDeclaration & node)355     PostResult visitDeclarationPost(TIntermDeclaration &node) override
356     {
357         Declaration decl = ViewDeclaration(node);
358         ASSERT(decl.symbol.variable().symbolType() != SymbolType::Empty);
359         ASSERT(!decl.symbol.variable().getType().isStructSpecifier());
360 
361         TIntermTyped *newInitExpr = pullMappedExpr(decl.initExpr, true);
362         if (decl.initExpr == newInitExpr)
363         {
364             pushStmt(node);
365         }
366         else
367         {
368             auto &newNode = *new TIntermDeclaration();
369             newNode.appendDeclarator(
370                 new TIntermBinary(TOperator::EOpInitialize, &decl.symbol, newInitExpr));
371             pushStmt(newNode);
372         }
373         return node;
374     }
375 
visitUnaryPost(TIntermUnary & node)376     PostResult visitUnaryPost(TIntermUnary &node) override
377     {
378         TIntermTyped *expr    = node.getOperand();
379         TIntermTyped *newExpr = pullMappedExpr(expr, false);
380         if (expr == newExpr)
381         {
382             pushBinding(node, node);
383         }
384         else
385         {
386             pushBinding(node, *new TIntermUnary(node.getOp(), newExpr, node.getFunction()));
387         }
388         return node;
389     }
390 
visitBinaryPre(TIntermBinary & node)391     PreResult visitBinaryPre(TIntermBinary &node) override
392     {
393         const TOperator op = node.getOp();
394         if (op == TOperator::EOpLogicalAnd || op == TOperator::EOpLogicalOr)
395         {
396             TIntermTyped *left  = node.getLeft();
397             TIntermTyped *right = node.getRight();
398 
399             PostResult leftResult = rebuild(*left);
400             ASSERT(leftResult.single());
401 
402             pushStacks();
403             PostResult rightResult = rebuild(*right);
404             ASSERT(rightResult.single());
405 
406             return {node, VisitBits::Post};
407         }
408 
409         return node;
410     }
411 
visitBinaryPost(TIntermBinary & node)412     PostResult visitBinaryPost(TIntermBinary &node) override
413     {
414         const TOperator op = node.getOp();
415         if (op == TOperator::EOpInitialize && getParentNode()->getAsDeclarationNode())
416         {
417             // Special case is handled by visitDeclarationPost
418             return node;
419         }
420 
421         TIntermTyped *left  = node.getLeft();
422         TIntermTyped *right = node.getRight();
423 
424         if (op == TOperator::EOpLogicalAnd || op == TOperator::EOpLogicalOr)
425         {
426             const Name name = mIdGen.createNewName();
427             auto *var = new TVariable(&mSymbolTable, name.rawName(), new TType(TBasicType::EbtBool),
428                                       name.symbolType());
429 
430             TIntermTyped *newRight   = pullMappedExpr(right, true);
431             TIntermBlock *rightBlock = &buildBlockWithTailAssign(*var, *newRight);
432             TIntermTyped *newLeft    = pullMappedExpr(left, true);
433 
434             TIntermTyped *cond = new TIntermSymbol(var);
435             if (op == TOperator::EOpLogicalOr)
436             {
437                 cond = new TIntermUnary(TOperator::EOpLogicalNot, cond, nullptr);
438             }
439 
440             pushStmt(*new TIntermDeclaration(var, newLeft));
441             pushStmt(*new TIntermIfElse(cond, rightBlock, nullptr));
442             if (!isStandaloneExpr(node))
443             {
444                 mExprMap[&node] = new TIntermSymbol(var);
445             }
446 
447             return node;
448         }
449 
450         const bool isAssign         = IsAssignment(op);
451         const bool isCompoundAssign = IsCompoundAssignment(op);
452         TIntermTyped *newLeft       = pullMappedExpr(left, false);
453         TIntermTyped *newRight      = pullMappedExpr(right, isAssign && !isCompoundAssign);
454         if (op == TOperator::EOpComma)
455         {
456             pushBinding(node, *newRight);
457             return node;
458         }
459         else
460         {
461             TIntermBinary *newNode;
462             if (left == newLeft && right == newRight)
463             {
464                 newNode = &node;
465             }
466             else
467             {
468                 newNode = new TIntermBinary(op, newLeft, newRight);
469             }
470             pushBinding(node, *newNode);
471             return node;
472         }
473     }
474 
visitTernaryPre(TIntermTernary & node)475     PreResult visitTernaryPre(TIntermTernary &node) override
476     {
477         PostResult condResult = rebuild(*node.getCondition());
478         ASSERT(condResult.single());
479 
480         pushStacks();
481         PostResult thenResult = rebuild(*node.getTrueExpression());
482         ASSERT(thenResult.single());
483 
484         pushStacks();
485         PostResult elseResult = rebuild(*node.getFalseExpression());
486         ASSERT(elseResult.single());
487 
488         return {node, VisitBits::Post};
489     }
490 
visitTernaryPost(TIntermTernary & node)491     PostResult visitTernaryPost(TIntermTernary &node) override
492     {
493         TIntermTyped *cond  = node.getCondition();
494         TIntermTyped *then  = node.getTrueExpression();
495         TIntermTyped *else_ = node.getFalseExpression();
496 
497         auto *var               = CreateTempVariable(&mSymbolTable, &node.getType(), EvqTemporary);
498         TIntermTyped *newElse   = pullMappedExpr(else_, false);
499         TIntermBlock *elseBlock = &buildBlockWithTailAssign(*var, *newElse);
500         TIntermTyped *newThen   = pullMappedExpr(then, true);
501         TIntermBlock *thenBlock = &buildBlockWithTailAssign(*var, *newThen);
502         TIntermTyped *newCond   = pullMappedExpr(cond, true);
503 
504         pushStmt(*new TIntermDeclaration{var});
505         pushStmt(*new TIntermIfElse(newCond, thenBlock, elseBlock));
506         if (!isStandaloneExpr(node))
507         {
508             mExprMap[&node] = new TIntermSymbol(var);
509         }
510 
511         return node;
512     }
513 
visitSwizzlePost(TIntermSwizzle & node)514     PostResult visitSwizzlePost(TIntermSwizzle &node) override
515     {
516         TIntermTyped *expr    = node.getOperand();
517         TIntermTyped *newExpr = pullMappedExpr(expr, false);
518         if (expr == newExpr)
519         {
520             pushBinding(node, node);
521         }
522         else
523         {
524             pushBinding(node, *new TIntermSwizzle(newExpr, node.getSwizzleOffsets()));
525         }
526         return node;
527     }
528 
visitAggregatePost(TIntermAggregate & node)529     PostResult visitAggregatePost(TIntermAggregate &node) override
530     {
531         TIntermSequence &args = *node.getSequence();
532         for (TIntermNode *&arg : args)
533         {
534             TIntermTyped *targ = arg->getAsTyped();
535             ASSERT(targ);
536             arg = pullMappedExpr(targ, false);
537         }
538         pushBinding(node, node);
539         return node;
540     }
541 
visitPreprocessorDirectivePost(TIntermPreprocessorDirective & node)542     PostResult visitPreprocessorDirectivePost(TIntermPreprocessorDirective &node) override
543     {
544         pushStmt(node);
545         return node;
546     }
547 
visitFunctionPrototypePost(TIntermFunctionPrototype & node)548     PostResult visitFunctionPrototypePost(TIntermFunctionPrototype &node) override
549     {
550         if (!getParentFunction())
551         {
552             pushStmt(node);
553         }
554         return node;
555     }
556 
visitCasePre(TIntermCase & node)557     PreResult visitCasePre(TIntermCase &node) override
558     {
559         if (TIntermTyped *cond = node.getCondition())
560         {
561             ASSERT(isTerminalExpr(*cond));
562         }
563         pushStmt(node);
564         return {node, VisitBits::Neither};
565     }
566 
visitSwitchPost(TIntermSwitch & node)567     PostResult visitSwitchPost(TIntermSwitch &node) override
568     {
569         TIntermTyped *init    = node.getInit();
570         TIntermTyped *newInit = pullMappedExpr(init, false);
571         if (init == newInit)
572         {
573             pushStmt(node);
574         }
575         else
576         {
577             pushStmt(*new TIntermSwitch(newInit, node.getStatementList()));
578         }
579 
580         return node;
581     }
582 
visitFunctionDefinitionPost(TIntermFunctionDefinition & node)583     PostResult visitFunctionDefinitionPost(TIntermFunctionDefinition &node) override
584     {
585         pushStmt(node);
586         return node;
587     }
588 
visitIfElsePost(TIntermIfElse & node)589     PostResult visitIfElsePost(TIntermIfElse &node) override
590     {
591         TIntermTyped *cond    = node.getCondition();
592         TIntermTyped *newCond = pullMappedExpr(cond, false);
593         if (cond == newCond)
594         {
595             pushStmt(node);
596         }
597         else
598         {
599             pushStmt(*new TIntermIfElse(newCond, node.getTrueBlock(), node.getFalseBlock()));
600         }
601         return node;
602     }
603 
visitBranchPost(TIntermBranch & node)604     PostResult visitBranchPost(TIntermBranch &node) override
605     {
606         TIntermTyped *expr    = node.getExpression();
607         TIntermTyped *newExpr = pullMappedExpr(expr, false);
608         if (expr == newExpr)
609         {
610             pushStmt(node);
611         }
612         else
613         {
614             pushStmt(*new TIntermBranch(node.getFlowOp(), newExpr));
615         }
616         return node;
617     }
618 
visitLoopPre(TIntermLoop & node)619     PreResult visitLoopPre(TIntermLoop &node) override
620     {
621         if (!rebuildInPlace(*node.getBody()))
622         {
623             UNREACHABLE();
624         }
625         pushStmt(node);
626         return {node, VisitBits::Neither};
627     }
628 
visitConstantUnionPost(TIntermConstantUnion & node)629     PostResult visitConstantUnionPost(TIntermConstantUnion &node) override
630     {
631         const TType &type = node.getType();
632         if (!type.isScalar() && !type.isVector() && !type.isMatrix())
633         {
634             pushBinding(node, node);
635         }
636         return node;
637     }
638 
visitGlobalQualifierDeclarationPost(TIntermGlobalQualifierDeclaration & node)639     PostResult visitGlobalQualifierDeclarationPost(TIntermGlobalQualifierDeclaration &node) override
640     {
641         // With the removal of RewriteGlobalQualifierDecls, we may encounter globals while
642         // seperating compound expressions.
643         pushStmt(node);
644         return node;
645     }
646 };
647 
648 }  // anonymous namespace
649 
////////////////////////////////////////////////////////////////////////////////
651 
SeparateCompoundExpressions(TCompiler & compiler,SymbolEnv & symbolEnv,IdGen & idGen,TIntermBlock & root)652 bool sh::SeparateCompoundExpressions(TCompiler &compiler,
653                                      SymbolEnv &symbolEnv,
654                                      IdGen &idGen,
655                                      TIntermBlock &root)
656 {
657     if (angle::GetBoolEnvironmentVar("GMT_DISABLE_SEPARATE_COMPOUND_EXPRESSIONS"))
658     {
659         return true;
660     }
661 
662     if (!PrePass(compiler).rebuildRoot(root))
663     {
664         return false;
665     }
666 
667     if (!Separator(compiler, symbolEnv, idGen).rebuildRoot(root))
668     {
669         return false;
670     }
671 
672     return true;
673 }
674