xref: /aosp_15_r20/external/angle/src/compiler/translator/tree_ops/SeparateDeclarations.cpp (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1 //
2 // Copyright 2002 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 #include "compiler/translator/tree_ops/SeparateDeclarations.h"
7 
8 #include "common/hash_containers.h"
9 #include "compiler/translator/IntermRebuild.h"
10 #include "compiler/translator/SymbolTable.h"
11 #include "compiler/translator/util.h"
12 
13 namespace sh
14 {
15 namespace
16 {
17 
18 class Separator final : private TIntermRebuild
19 {
20   public:
Separator(TCompiler & compiler,bool separateCompoundStructDeclarations)21     Separator(TCompiler &compiler, bool separateCompoundStructDeclarations)
22         : TIntermRebuild(compiler, true, true),
23           mSeparateCompoundStructDeclarations(separateCompoundStructDeclarations)
24     {}
25     using TIntermRebuild::rebuildRoot;
26 
27   private:
recordModifiedStructVariables(TIntermDeclaration & node)28     void recordModifiedStructVariables(TIntermDeclaration &node)
29     {
30         ASSERT(!mNewStructure);  // No nested struct declarations.
31         TIntermSequence &sequence = *node.getSequence();
32         if (sequence.size() <= 1 && !mSeparateCompoundStructDeclarations)
33         {
34             return;
35         }
36         TIntermTyped *declarator    = sequence.at(0)->getAsTyped();
37         const TType &declaratorType = declarator->getType();
38         const TStructure *structure = declaratorType.getStruct();
39         // Rewrite variable declarations that specify structs AND variable(s) at the same
40         // time.
41         if (!structure || !declaratorType.isStructSpecifier())
42         {
43             return;
44         }
45         if (mSeparateCompoundStructDeclarations && sequence.size() == 1)
46         {
47             if (TIntermSymbol *symbol = declarator->getAsSymbolNode(); symbol != nullptr)
48             {
49                 if (symbol->variable().symbolType() == SymbolType::Empty)
50                 {
51                     return;
52                 }
53             }
54         }
55         // By default, struct specifier changes for all variables except the first one.
56         uint32_t index = 1;
57         if (structure->symbolType() == SymbolType::Empty)
58         {
59             TStructure *newStructure =
60                 new TStructure(&mSymbolTable, kEmptyImmutableString, &structure->fields(),
61                                SymbolType::AngleInternal);
62             newStructure->setAtGlobalScope(structure->atGlobalScope());
63             structure     = newStructure;
64             // Adding name causes the struct type to change, so also the first variable variable
65             // needs rewriting.
66             index = 0;
67         }
68         if (mSeparateCompoundStructDeclarations)
69         {
70             mNewStructure = structure;
71             // Separating struct and variable declaration causes the variable type to change
72             // from specifying to not-specifying, so also the first variable needs rewriting.
73             index = 0;
74         }
75 
76         for (; index < sequence.size(); ++index)
77         {
78             Declaration decl              = ViewDeclaration(node, index);
79             const TVariable &var          = decl.symbol.variable();
80             const TType &varType          = var.getType();
81             const bool newTypeIsSpecifier = index == 0 && !mSeparateCompoundStructDeclarations;
82             TType *newType                = new TType(structure, newTypeIsSpecifier);
83             newType->setQualifier(varType.getQualifier());
84             newType->makeArrays(varType.getArraySizes());
85             TVariable *newVar = new TVariable(&mSymbolTable, var.name(), newType, var.symbolType());
86             mStructVariables.insert(std::make_pair(&var, newVar));
87         }
88     }
89 
visitDeclarationPre(TIntermDeclaration & node)90     PreResult visitDeclarationPre(TIntermDeclaration &node) override
91     {
92         recordModifiedStructVariables(node);
93         return node;
94     }
95 
visitDeclarationPost(TIntermDeclaration & node)96     PostResult visitDeclarationPost(TIntermDeclaration &node) override
97     {
98         TIntermSequence &sequence = *node.getSequence();
99         if (sequence.size() <= 1 && !mNewStructure)
100         {
101             return node;
102         }
103         std::vector<TIntermNode *> replacements;
104         if (mNewStructure)
105         {
106             TType *newType = new TType(mNewStructure, true);
107             if (mNewStructure->atGlobalScope())
108             {
109                 newType->setQualifier(EvqGlobal);
110             }
111             TVariable *structVar =
112                 new TVariable(&mSymbolTable, kEmptyImmutableString, newType, SymbolType::Empty);
113             TIntermDeclaration *replacement = new TIntermDeclaration({structVar});
114             replacement->setLine(node.getLine());
115             replacements.push_back(replacement);
116             mNewStructure = nullptr;
117         }
118         for (uint32_t index = 0; index < sequence.size(); ++index)
119         {
120             TIntermTyped *declarator        = sequence.at(index)->getAsTyped();
121             TIntermDeclaration *replacement = new TIntermDeclaration({declarator});
122             replacement->setLine(declarator->getLine());
123             replacements.push_back(replacement);
124         }
125         return PostResult::Multi(std::move(replacements));
126     }
127 
visitSymbolPre(TIntermSymbol & symbolNode)128     PreResult visitSymbolPre(TIntermSymbol &symbolNode) override
129     {
130         auto it = mStructVariables.find(&symbolNode.variable());
131         if (it == mStructVariables.end())
132         {
133             return symbolNode;
134         }
135         return *new TIntermSymbol(it->second);
136     }
137 
visitFunctionPrototypePre(TIntermFunctionPrototype & node)138     PreResult visitFunctionPrototypePre(TIntermFunctionPrototype &node) override
139     {
140         const TFunction *function = node.getFunction();
141         auto it                   = mFunctionsToReplace.find(function);
142         if (it != mFunctionsToReplace.end())
143         {
144             TIntermFunctionPrototype *newFuncProto = new TIntermFunctionPrototype(it->second);
145             return newFuncProto;
146         }
147         else if (node.getType().isStructSpecifier())
148         {
149             const TType &oldType        = node.getType();
150             const TStructure *structure = oldType.getStruct();
151             // Name unnamed inline structs
152             if (structure->symbolType() == SymbolType::Empty)
153             {
154                 TStructure *newStructure =
155                     new TStructure(&mSymbolTable, kEmptyImmutableString, &structure->fields(),
156                                    SymbolType::AngleInternal);
157                 newStructure->setAtGlobalScope(structure->atGlobalScope());
158                 structure = newStructure;
159             }
160             TType *newType = new TType(structure, true);
161             if (structure->atGlobalScope())
162             {
163                 newType->setQualifier(EvqGlobal);
164             }
165             TVariable *structVar =
166                 new TVariable(&mSymbolTable, ImmutableString(""), newType, SymbolType::Empty);
167             TType *returnType = new TType(structure, false);
168             if (oldType.isArray())
169             {
170                 returnType->makeArrays(oldType.getArraySizes());
171             }
172             returnType->setQualifier(oldType.getQualifier());
173 
174             const TFunction *oldFunc = function;
175             ASSERT(oldFunc->symbolType() == SymbolType::UserDefined);
176 
177             const TFunction *newFunc     = cloneFunctionAndChangeReturnType(oldFunc, returnType);
178             mFunctionsToReplace[oldFunc] = newFunc;
179             if (getParentNode()->getAsFunctionDefinition() != nullptr)
180             {
181                 mNewFunctionReturnStructDeclaration = new TIntermDeclaration({structVar});
182                 return new TIntermFunctionPrototype(newFunc);
183             }
184             return PreResult::Multi(
185                 {new TIntermDeclaration({structVar}), new TIntermFunctionPrototype(newFunc)});
186         }
187 
188         return node;
189     }
190 
visitFunctionDefinitionPost(TIntermFunctionDefinition & node)191     PostResult visitFunctionDefinitionPost(TIntermFunctionDefinition &node) override
192     {
193         if (mNewFunctionReturnStructDeclaration)
194         {
195             return PostResult::Multi(
196                 {std::exchange(mNewFunctionReturnStructDeclaration, nullptr), &node});
197         }
198         return node;
199     }
200 
visitAggregatePre(TIntermAggregate & node)201     PreResult visitAggregatePre(TIntermAggregate &node) override
202     {
203         const TFunction *function = node.getFunction();
204         auto it                   = mFunctionsToReplace.find(function);
205         if (it != mFunctionsToReplace.end())
206         {
207             TIntermAggregate *replacementNode =
208                 TIntermAggregate::CreateFunctionCall(*it->second, node.getSequence());
209 
210             return PreResult(replacementNode, VisitBits::Children);
211         }
212 
213         return node;
214     }
215 
216   private:
cloneFunctionAndChangeReturnType(const TFunction * oldFunc,const TType * newReturnType)217     const TFunction *cloneFunctionAndChangeReturnType(const TFunction *oldFunc,
218                                                       const TType *newReturnType)
219 
220     {
221         ASSERT(oldFunc->symbolType() == SymbolType::UserDefined);
222 
223         TFunction *newFunc = new TFunction(&mSymbolTable, oldFunc->name(), oldFunc->symbolType(),
224                                            newReturnType, oldFunc->isKnownToNotHaveSideEffects());
225 
226         if (oldFunc->isDefined())
227         {
228             newFunc->setDefined();
229         }
230 
231         if (oldFunc->hasPrototypeDeclaration())
232         {
233             newFunc->setHasPrototypeDeclaration();
234         }
235 
236         const size_t paramCount = oldFunc->getParamCount();
237         for (size_t i = 0; i < paramCount; ++i)
238         {
239             const TVariable *var = oldFunc->getParam(i);
240             newFunc->addParameter(var);
241         }
242 
243         return newFunc;
244     }
245 
246     angle::HashMap<const TFunction *, const TFunction *> mFunctionsToReplace;
247     // New structure separated from function declaration.
248     TIntermDeclaration *mNewFunctionReturnStructDeclaration = nullptr;
249 
250     // New structure from compound declaration.
251     const TStructure *mNewStructure = nullptr;
252     // Old struct variable to new struct variable mapping.
253     angle::HashMap<const TVariable *, TVariable *> mStructVariables;
254     const bool mSeparateCompoundStructDeclarations;
255 };
256 
257 }  // namespace
258 
SeparateDeclarations(TCompiler & compiler,TIntermBlock & root,bool separateCompoundStructDeclarations)259 bool SeparateDeclarations(TCompiler &compiler,
260                           TIntermBlock &root,
261                           bool separateCompoundStructDeclarations)
262 {
263     Separator separator(compiler, separateCompoundStructDeclarations);
264     return separator.rebuildRoot(root);
265 }
266 
267 }  // namespace sh
268