xref: /aosp_15_r20/external/angle/src/compiler/translator/tree_ops/RewriteArrayOfArrayOfOpaqueUniforms.cpp (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1 //
2 // Copyright 2019 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // RewriteArrayOfArrayOfOpaqueUniforms: Flatten arrays of arrays of opaque uniforms.
7 //
8 
9 #include "compiler/translator/tree_ops/RewriteArrayOfArrayOfOpaqueUniforms.h"
10 
11 #include "compiler/translator/Compiler.h"
12 #include "compiler/translator/ImmutableStringBuilder.h"
13 #include "compiler/translator/SymbolTable.h"
14 #include "compiler/translator/tree_util/IntermNode_util.h"
15 #include "compiler/translator/tree_util/IntermTraverse.h"
16 #include "compiler/translator/tree_util/ReplaceVariable.h"
17 
18 namespace sh
19 {
20 namespace
21 {
22 struct UniformData
23 {
24     // Corresponding to an array of array of opaque uniform variable, this is the flattened variable
25     // that is replacing it.
26     const TVariable *flattened;
27     // Assume a general case of array declaration with N dimensions:
28     //
29     //     uniform type u[Dn]..[D2][D1];
30     //
31     // Let's define
32     //
33     //     Pn = D(n-1)*...*D2*D1
34     //
35     // In that case, we have:
36     //
37     //     u[In]         = ac + In*Pn
38     //     u[In][I(n-1)] = ac + In*Pn + I(n-1)*P(n-1)
39     //     u[In]...[Ii]  = ac + In*Pn + ... + Ii*Pi
40     //
41     // This array contains Pi.  Note that the like TType::mArraySizes, the last element is the
42     // outermost dimension.  Element 0 is necessarily 1.
43     TVector<unsigned int> mSubArraySizes;
44 };
45 
46 using UniformMap = angle::HashMap<const TVariable *, UniformData>;
47 
48 TIntermTyped *RewriteArrayOfArraySubscriptExpression(TCompiler *compiler,
49                                                      TIntermBinary *node,
50                                                      const UniformMap &uniformMap);
51 
52 // Given an expression, this traverser calculates a new expression where array of array of opaque
53 // uniforms are replaced with their flattened ones.  In particular, this is run on the right node of
54 // EOpIndexIndirect binary nodes, so that the expression in the index gets a chance to go through
55 // this transformation.
56 class RewriteExpressionTraverser final : public TIntermTraverser
57 {
58   public:
RewriteExpressionTraverser(TCompiler * compiler,const UniformMap & uniformMap)59     explicit RewriteExpressionTraverser(TCompiler *compiler, const UniformMap &uniformMap)
60         : TIntermTraverser(true, false, false), mCompiler(compiler), mUniformMap(uniformMap)
61     {}
62 
visitBinary(Visit visit,TIntermBinary * node)63     bool visitBinary(Visit visit, TIntermBinary *node) override
64     {
65         TIntermTyped *rewritten =
66             RewriteArrayOfArraySubscriptExpression(mCompiler, node, mUniformMap);
67         if (rewritten == nullptr)
68         {
69             return true;
70         }
71 
72         queueReplacement(rewritten, OriginalNode::IS_DROPPED);
73 
74         // Don't iterate as the expression is rewritten.
75         return false;
76     }
77 
visitSymbol(TIntermSymbol * node)78     void visitSymbol(TIntermSymbol *node) override
79     {
80         // We cannot reach here for an opaque uniform that is being replaced.  visitBinary should
81         // have taken care of it.
82         ASSERT(!IsOpaqueType(node->getType().getBasicType()) ||
83                mUniformMap.find(&node->variable()) == mUniformMap.end());
84     }
85 
86   private:
87     TCompiler *mCompiler;
88 
89     const UniformMap &mUniformMap;
90 };
91 
92 // Rewrite the index of an EOpIndexIndirect expression.  The root can never need replacing, because
93 // it cannot be an opaque uniform itself.
RewriteIndexExpression(TCompiler * compiler,TIntermTyped * expression,const UniformMap & uniformMap)94 void RewriteIndexExpression(TCompiler *compiler,
95                             TIntermTyped *expression,
96                             const UniformMap &uniformMap)
97 {
98     RewriteExpressionTraverser traverser(compiler, uniformMap);
99     expression->traverse(&traverser);
100     bool valid = traverser.updateTree(compiler, expression);
101     ASSERT(valid);
102 }
103 
104 // Given an expression such as the following:
105 //
106 //                                              EOpIndex(In)Direct (opaque uniform)
107 //                                                    /           \
108 //                                            EOpIndex(In)Direct   I1
109 //                                                  /           \
110 //                                                ...            I2
111 //                                            /
112 //                                    EOpIndex(In)Direct
113 //                                          /           \
114 //                                      uniform          In
115 //
116 // produces:
117 //
118 //          EOpIndex(In)Direct
119 //            /        \
120 //        uniform    In*Pn + ... + I2*P2 + I1*P1
121 //
RewriteArrayOfArraySubscriptExpression(TCompiler * compiler,TIntermBinary * node,const UniformMap & uniformMap)122 TIntermTyped *RewriteArrayOfArraySubscriptExpression(TCompiler *compiler,
123                                                      TIntermBinary *node,
124                                                      const UniformMap &uniformMap)
125 {
126     // Only interested in opaque uniforms.
127     if (!IsOpaqueType(node->getType().getBasicType()))
128     {
129         return nullptr;
130     }
131 
132     TIntermSymbol *opaqueUniform = nullptr;
133 
134     // Iterate once and find the opaque uniform that's being indexed.
135     TIntermBinary *iter = node;
136     while (opaqueUniform == nullptr)
137     {
138         ASSERT(iter->getOp() == EOpIndexDirect || iter->getOp() == EOpIndexIndirect);
139 
140         opaqueUniform = iter->getLeft()->getAsSymbolNode();
141         iter          = iter->getLeft()->getAsBinaryNode();
142     }
143 
144     // If not being replaced, there's nothing to do.
145     auto flattenedIter = uniformMap.find(&opaqueUniform->variable());
146     if (flattenedIter == uniformMap.end())
147     {
148         return nullptr;
149     }
150 
151     const UniformData &data = flattenedIter->second;
152 
153     // Iterate again and build the index expression.  The index expression constitutes the sum of
154     // the variable indices plus a constant offset calculated from the constant indices.  For
155     // example, smplr[1][x][2][y] will have an index of x*P3 + y*P1 + c, where c = (1*P4 + 2*P2).
156     unsigned int constantOffset = 0;
157     TIntermTyped *variableIndex = nullptr;
158 
159     // Since the opaque uniforms are fully subscripted, we know exactly how many EOpIndex* nodes
160     // there should be.
161     for (size_t dimIndex = 0; dimIndex < data.mSubArraySizes.size(); ++dimIndex)
162     {
163         ASSERT(node);
164 
165         unsigned int subArraySize = data.mSubArraySizes[dimIndex];
166 
167         switch (node->getOp())
168         {
169             case EOpIndexDirect:
170                 // Accumulate the constant index.
171                 constantOffset +=
172                     node->getRight()->getAsConstantUnion()->getIConst(0) * subArraySize;
173                 break;
174             case EOpIndexIndirect:
175             {
176                 // Run RewriteExpressionTraverser on the right node.  It may itself be an expression
177                 // with an array of array of opaque uniform inside that needs to be rewritten.
178                 TIntermTyped *indexExpression = node->getRight();
179                 RewriteIndexExpression(compiler, indexExpression, uniformMap);
180 
181                 // Scale and accumulate.
182                 if (subArraySize != 1)
183                 {
184                     indexExpression =
185                         new TIntermBinary(EOpMul, indexExpression, CreateIndexNode(subArraySize));
186                 }
187 
188                 if (variableIndex == nullptr)
189                 {
190                     variableIndex = indexExpression;
191                 }
192                 else
193                 {
194                     variableIndex = new TIntermBinary(EOpAdd, variableIndex, indexExpression);
195                 }
196                 break;
197             }
198             default:
199                 UNREACHABLE();
200                 break;
201         }
202 
203         node = node->getLeft()->getAsBinaryNode();
204     }
205 
206     // Add the two accumulated indices together.
207     TIntermTyped *index = nullptr;
208     if (constantOffset == 0 && variableIndex != nullptr)
209     {
210         // No constant offset, but there's variable offset.  Take that as offset.
211         index = variableIndex;
212     }
213     else
214     {
215         // Either the constant offset is non zero, or there's no variable offset (so constant 0
216         // should be used).
217         index = CreateIndexNode(constantOffset);
218 
219         if (variableIndex)
220         {
221             index = new TIntermBinary(EOpAdd, index, variableIndex);
222         }
223     }
224 
225     // Create an index into the flattened uniform.
226     TOperator op = variableIndex ? EOpIndexIndirect : EOpIndexDirect;
227     return new TIntermBinary(op, new TIntermSymbol(data.flattened), index);
228 }
229 
230 // Traverser that takes:
231 //
232 //     uniform sampler/image/atomic_uint u[N][M]..
233 //
234 // and transforms it to:
235 //
236 //     uniform sampler/image/atomic_uint u[N * M * ..]
237 //
238 // MonomorphizeUnsupportedFunctions makes it impossible for this array to be partially
239 // subscripted, or passed as argument to a function unsubscripted.  This means that every encounter
240 // of this uniform can be expected to be fully subscripted.
241 //
242 class RewriteArrayOfArrayOfOpaqueUniformsTraverser : public TIntermTraverser
243 {
244   public:
RewriteArrayOfArrayOfOpaqueUniformsTraverser(TCompiler * compiler,TSymbolTable * symbolTable)245     RewriteArrayOfArrayOfOpaqueUniformsTraverser(TCompiler *compiler, TSymbolTable *symbolTable)
246         : TIntermTraverser(true, false, false, symbolTable), mCompiler(compiler)
247     {}
248 
visitDeclaration(Visit visit,TIntermDeclaration * node)249     bool visitDeclaration(Visit visit, TIntermDeclaration *node) override
250     {
251         if (!mInGlobalScope)
252         {
253             return true;
254         }
255 
256         const TIntermSequence &sequence = *(node->getSequence());
257 
258         TIntermTyped *variable = sequence.front()->getAsTyped();
259         const TType &type      = variable->getType();
260         bool isOpaqueUniform =
261             type.getQualifier() == EvqUniform && IsOpaqueType(type.getBasicType());
262 
263         // Only interested in array of array of opaque uniforms.
264         if (!isOpaqueUniform || !type.isArrayOfArrays())
265         {
266             return false;
267         }
268 
269         // Opaque uniforms cannot have initializers, so the declaration must necessarily be a
270         // symbol.
271         TIntermSymbol *symbol = variable->getAsSymbolNode();
272         ASSERT(symbol != nullptr);
273 
274         const TVariable *uniformVariable = &symbol->variable();
275 
276         // Create an entry in the map.
277         ASSERT(mUniformMap.find(uniformVariable) == mUniformMap.end());
278         UniformData &data = mUniformMap[uniformVariable];
279 
280         // Calculate the accumulated dimension products.  See UniformData::mSubArraySizes.
281         const TSpan<const unsigned int> &arraySizes = type.getArraySizes();
282         mUniformMap[uniformVariable].mSubArraySizes.resize(arraySizes.size());
283         unsigned int runningProduct = 1;
284         for (size_t dimension = 0; dimension < arraySizes.size(); ++dimension)
285         {
286             data.mSubArraySizes[dimension] = runningProduct;
287             runningProduct *= arraySizes[dimension];
288         }
289 
290         // Create a replacement variable with the array flattened.
291         TType *newType = new TType(type);
292         newType->toArrayBaseType();
293         newType->makeArray(runningProduct);
294 
295         data.flattened = new TVariable(mSymbolTable, uniformVariable->name(), newType,
296                                        uniformVariable->symbolType());
297 
298         TIntermDeclaration *decl = new TIntermDeclaration;
299         decl->appendDeclarator(new TIntermSymbol(data.flattened));
300 
301         queueReplacement(decl, OriginalNode::IS_DROPPED);
302         return false;
303     }
304 
visitFunctionDefinition(Visit visit,TIntermFunctionDefinition * node)305     bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override
306     {
307         // As an optimization, don't bother inspecting functions if there aren't any opaque uniforms
308         // to replace.
309         return !mUniformMap.empty();
310     }
311 
312     // Same implementation as in RewriteExpressionTraverser.  That traverser cannot replace root.
visitBinary(Visit visit,TIntermBinary * node)313     bool visitBinary(Visit visit, TIntermBinary *node) override
314     {
315         TIntermTyped *rewritten =
316             RewriteArrayOfArraySubscriptExpression(mCompiler, node, mUniformMap);
317         if (rewritten == nullptr)
318         {
319             return true;
320         }
321 
322         queueReplacement(rewritten, OriginalNode::IS_DROPPED);
323 
324         // Don't iterate as the expression is rewritten.
325         return false;
326     }
327 
visitSymbol(TIntermSymbol * node)328     void visitSymbol(TIntermSymbol *node) override
329     {
330         ASSERT(!IsOpaqueType(node->getType().getBasicType()) ||
331                mUniformMap.find(&node->variable()) == mUniformMap.end());
332     }
333 
334   private:
335     TCompiler *mCompiler;
336     UniformMap mUniformMap;
337 };
338 }  // anonymous namespace
339 
RewriteArrayOfArrayOfOpaqueUniforms(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable)340 bool RewriteArrayOfArrayOfOpaqueUniforms(TCompiler *compiler,
341                                          TIntermBlock *root,
342                                          TSymbolTable *symbolTable)
343 {
344     RewriteArrayOfArrayOfOpaqueUniformsTraverser traverser(compiler, symbolTable);
345     root->traverse(&traverser);
346     return traverser.updateTree(compiler, root);
347 }
348 }  // namespace sh
349