xref: /aosp_15_r20/external/angle/src/compiler/translator/msl/DiscoverDependentFunctions.cpp (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include <cstring>
8 #include <unordered_map>
9 #include <unordered_set>
10 
11 #include "compiler/translator/msl/DiscoverDependentFunctions.h"
12 #include "compiler/translator/msl/DiscoverEnclosingFunctionTraverser.h"
13 #include "compiler/translator/msl/MapFunctionsToDefinitions.h"
14 
15 using namespace sh;
16 
17 ////////////////////////////////////////////////////////////////////////////////
18 
19 namespace
20 {
21 
22 class Discoverer : public DiscoverEnclosingFunctionTraverser
23 {
24   private:
25     const std::function<bool(const TVariable &)> &mVars;
26     const FunctionToDefinition &mFuncToDef;
27     std::unordered_set<const TFunction *> mNonDepFunctions;
28 
29   public:
30     std::unordered_set<const TFunction *> mDepFunctions;
31 
32   public:
Discoverer(const std::function<bool (const TVariable &)> & vars,const FunctionToDefinition & funcToDef)33     Discoverer(const std::function<bool(const TVariable &)> &vars,
34                const FunctionToDefinition &funcToDef)
35         : DiscoverEnclosingFunctionTraverser(true, false, true), mVars(vars), mFuncToDef(funcToDef)
36     {}
37 
visitSymbol(TIntermSymbol * symbolNode)38     void visitSymbol(TIntermSymbol *symbolNode) override
39     {
40         const TVariable &var = symbolNode->variable();
41         if (!mVars(var))
42         {
43             return;
44         }
45         const TFunction *owner = discoverEnclosingFunction(symbolNode);
46         if (owner)
47         {
48             mDepFunctions.insert(owner);
49         }
50     }
51 
visitAggregate(Visit visit,TIntermAggregate * aggregateNode)52     bool visitAggregate(Visit visit, TIntermAggregate *aggregateNode) override
53     {
54         if (visit != Visit::PreVisit)
55         {
56             return true;
57         }
58 
59         if (!aggregateNode->isConstructor())
60         {
61             const TFunction *func = aggregateNode->getFunction();
62 
63             if (mNonDepFunctions.find(func) != mNonDepFunctions.end())
64             {
65                 return true;
66             }
67 
68             if (mDepFunctions.find(func) == mDepFunctions.end())
69             {
70                 auto it = mFuncToDef.find(func);
71                 if (it == mFuncToDef.end())
72                 {
73                     return true;
74                 }
75 
76                 // Recursion is banned in GLSL, so I believe AngleIR has this property too.
77                 // This implementation assumes (direct and mutual) recursion is prohibited.
78                 TIntermFunctionDefinition &funcDefNode = *it->second;
79                 funcDefNode.traverse(this);
80                 if (mNonDepFunctions.find(func) != mNonDepFunctions.end())
81                 {
82                     return true;
83                 }
84                 ASSERT(mDepFunctions.find(func) != mDepFunctions.end());
85             }
86 
87             const TFunction *owner = discoverEnclosingFunction(aggregateNode);
88             ASSERT(owner);
89             mDepFunctions.insert(owner);
90         }
91 
92         return true;
93     }
94 
visitFunctionDefinition(Visit visit,TIntermFunctionDefinition * funcDefNode)95     bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *funcDefNode) override
96     {
97         const TFunction *func = funcDefNode->getFunction();
98 
99         if (visit != Visit::PostVisit)
100         {
101             if (mDepFunctions.find(func) != mDepFunctions.end())
102             {
103                 return false;
104             }
105 
106             if (mNonDepFunctions.find(func) != mNonDepFunctions.end())
107             {
108                 return false;
109             }
110 
111             return true;
112         }
113 
114         if (mDepFunctions.find(func) == mDepFunctions.end())
115         {
116             mNonDepFunctions.insert(func);
117         }
118 
119         return true;
120     }
121 };
122 
123 }  // namespace
124 
DiscoverDependentFunctions(TIntermBlock & root,const std::function<bool (const TVariable &)> & vars)125 std::unordered_set<const TFunction *> sh::DiscoverDependentFunctions(
126     TIntermBlock &root,
127     const std::function<bool(const TVariable &)> &vars)
128 {
129     const FunctionToDefinition funcToDef = MapFunctionsToDefinitions(root);
130     Discoverer discoverer(vars, funcToDef);
131     root.traverse(&discoverer);
132     return std::move(discoverer.mDepFunctions);
133 }
134