xref: /aosp_15_r20/external/angle/src/compiler/translator/tree_ops/MonomorphizeUnsupportedFunctions.cpp (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1 //
2 // Copyright 2021 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // MonomorphizeUnsupportedFunctions: Monomorphize functions that are called with
7 // parameters that are incompatible with both Vulkan GLSL and Metal.
8 //
9 
10 #include "compiler/translator/tree_ops/MonomorphizeUnsupportedFunctions.h"
11 
12 #include "compiler/translator/ImmutableStringBuilder.h"
13 #include "compiler/translator/SymbolTable.h"
14 #include "compiler/translator/tree_util/IntermNode_util.h"
15 #include "compiler/translator/tree_util/IntermTraverse.h"
16 #include "compiler/translator/tree_util/ReplaceVariable.h"
17 
18 namespace sh
19 {
20 namespace
21 {
22 struct Argument
23 {
24     size_t argumentIndex;
25     TIntermTyped *argument;
26 };
27 
28 struct FunctionData
29 {
30     // Whether the original function is used.  If this is false, the function can be removed because
31     // all callers have been modified.
32     bool isOriginalUsed;
33     // The original definition of the function, used to create the monomorphized version.
34     TIntermFunctionDefinition *originalDefinition;
35     // List of monomorphized versions of this function.  They will be added next to the original
36     // version (or replace it).
37     TVector<TIntermFunctionDefinition *> monomorphizedDefinitions;
38 };
39 
40 using FunctionMap = angle::HashMap<const TFunction *, FunctionData>;
41 
42 // Traverse the function definitions and initialize the map.  Allows visitAggregate to have access
43 // to TIntermFunctionDefinition even when the function is only forward declared at that point.
InitializeFunctionMap(TIntermBlock * root,FunctionMap * functionMapOut)44 void InitializeFunctionMap(TIntermBlock *root, FunctionMap *functionMapOut)
45 {
46     TIntermSequence &sequence = *root->getSequence();
47 
48     for (TIntermNode *node : sequence)
49     {
50         TIntermFunctionDefinition *asFuncDef = node->getAsFunctionDefinition();
51         if (asFuncDef != nullptr)
52         {
53             const TFunction *function = asFuncDef->getFunction();
54             ASSERT(function && functionMapOut->find(function) == functionMapOut->end());
55             (*functionMapOut)[function] = FunctionData{false, asFuncDef, {}};
56         }
57     }
58 }
59 
GetBaseUniform(TIntermTyped * node,bool * isSamplerInStructOut)60 const TVariable *GetBaseUniform(TIntermTyped *node, bool *isSamplerInStructOut)
61 {
62     *isSamplerInStructOut = false;
63 
64     while (node->getAsBinaryNode())
65     {
66         TIntermBinary *asBinary = node->getAsBinaryNode();
67 
68         TOperator op = asBinary->getOp();
69 
70         // No opaque uniform can be inside an interface block.
71         if (op == EOpIndexDirectInterfaceBlock)
72         {
73             return nullptr;
74         }
75 
76         if (op == EOpIndexDirectStruct)
77         {
78             *isSamplerInStructOut = true;
79         }
80 
81         node = asBinary->getLeft();
82     }
83 
84     // Only interested in uniform opaque types.  If a function call within another function uses
85     // opaque uniforms in an unsupported way, it will be replaced in a follow up pass after the
86     // calling function is monomorphized.
87     if (node->getType().getQualifier() != EvqUniform)
88     {
89         return nullptr;
90     }
91 
92     ASSERT(IsOpaqueType(node->getType().getBasicType()) ||
93            node->getType().isStructureContainingSamplers());
94 
95     TIntermSymbol *asSymbol = node->getAsSymbolNode();
96     ASSERT(asSymbol);
97 
98     return &asSymbol->variable();
99 }
100 
ExtractSideEffects(TSymbolTable * symbolTable,TIntermTyped * node,TIntermSequence * replacementIndices)101 TIntermTyped *ExtractSideEffects(TSymbolTable *symbolTable,
102                                  TIntermTyped *node,
103                                  TIntermSequence *replacementIndices)
104 {
105     TIntermTyped *withoutSideEffects = node->deepCopy();
106 
107     for (TIntermBinary *asBinary = withoutSideEffects->getAsBinaryNode(); asBinary;
108          asBinary                = asBinary->getLeft()->getAsBinaryNode())
109     {
110         TOperator op        = asBinary->getOp();
111         TIntermTyped *index = asBinary->getRight();
112 
113         if (op == EOpIndexDirectStruct)
114         {
115             break;
116         }
117 
118         // No side effects with constant expressions.
119         if (op == EOpIndexDirect)
120         {
121             ASSERT(index->getAsConstantUnion());
122             continue;
123         }
124 
125         ASSERT(op == EOpIndexIndirect);
126 
127         // If the index is a symbol, there's no side effect, so leave it as-is.
128         if (index->getAsSymbolNode())
129         {
130             continue;
131         }
132 
133         // Otherwise create a temp variable initialized with the index and use that temp variable as
134         // the index.
135         TIntermDeclaration *tempDecl = nullptr;
136         TVariable *tempVar = DeclareTempVariable(symbolTable, index, EvqTemporary, &tempDecl);
137 
138         replacementIndices->push_back(tempDecl);
139         asBinary->replaceChildNode(index, new TIntermSymbol(tempVar));
140     }
141 
142     return withoutSideEffects;
143 }
144 
CreateMonomorphizedFunctionCallArgs(const TIntermSequence & originalCallArguments,const TVector<Argument> & replacedArguments,TIntermSequence * substituteArgsOut)145 void CreateMonomorphizedFunctionCallArgs(const TIntermSequence &originalCallArguments,
146                                          const TVector<Argument> &replacedArguments,
147                                          TIntermSequence *substituteArgsOut)
148 {
149     size_t nextReplacedArg = 0;
150     for (size_t argIndex = 0; argIndex < originalCallArguments.size(); ++argIndex)
151     {
152         if (nextReplacedArg >= replacedArguments.size() ||
153             argIndex != replacedArguments[nextReplacedArg].argumentIndex)
154         {
155             // Not replaced, keep argument as is.
156             substituteArgsOut->push_back(originalCallArguments[argIndex]);
157         }
158         else
159         {
160             TIntermTyped *argument = replacedArguments[nextReplacedArg].argument;
161 
162             // Iterate over indices of the argument and create a new arg for every non-const
163             // index.  Note that the index itself may be an expression, and it may require further
164             // substitution in the next pass.
165             while (argument->getAsBinaryNode())
166             {
167                 TIntermBinary *asBinary = argument->getAsBinaryNode();
168                 if (asBinary->getOp() == EOpIndexIndirect)
169                 {
170                     TIntermTyped *index = asBinary->getRight();
171                     substituteArgsOut->push_back(index->deepCopy());
172                 }
173                 argument = asBinary->getLeft();
174             }
175 
176             ++nextReplacedArg;
177         }
178     }
179 }
180 
MonomorphizeFunction(TSymbolTable * symbolTable,const TFunction * original,TVector<Argument> * replacedArguments,VariableReplacementMap * argumentMapOut)181 const TFunction *MonomorphizeFunction(TSymbolTable *symbolTable,
182                                       const TFunction *original,
183                                       TVector<Argument> *replacedArguments,
184                                       VariableReplacementMap *argumentMapOut)
185 {
186     TFunction *substituteFunction =
187         new TFunction(symbolTable, kEmptyImmutableString, SymbolType::AngleInternal,
188                       &original->getReturnType(), original->isKnownToNotHaveSideEffects());
189 
190     size_t nextReplacedArg = 0;
191     for (size_t paramIndex = 0; paramIndex < original->getParamCount(); ++paramIndex)
192     {
193         const TVariable *originalParam = original->getParam(paramIndex);
194 
195         if (nextReplacedArg >= replacedArguments->size() ||
196             paramIndex != (*replacedArguments)[nextReplacedArg].argumentIndex)
197         {
198             TVariable *substituteArgument =
199                 new TVariable(symbolTable, originalParam->name(), &originalParam->getType(),
200                               originalParam->symbolType());
201             // Not replaced, add an identical parameter.
202             substituteFunction->addParameter(substituteArgument);
203             (*argumentMapOut)[originalParam] = new TIntermSymbol(substituteArgument);
204         }
205         else
206         {
207             TIntermTyped *substituteArgument = (*replacedArguments)[nextReplacedArg].argument;
208             (*argumentMapOut)[originalParam] = substituteArgument;
209 
210             // Iterate over indices of the argument and create a new parameter for every non-const
211             // index (which may be an expression).  Replace the symbol in the argument with a
212             // variable of the index type.  This is later used to replace the parameter in the
213             // function body.
214             while (substituteArgument->getAsBinaryNode())
215             {
216                 TIntermBinary *asBinary = substituteArgument->getAsBinaryNode();
217                 if (asBinary->getOp() == EOpIndexIndirect)
218                 {
219                     TIntermTyped *index = asBinary->getRight();
220                     TType *indexType    = new TType(index->getType());
221                     indexType->setQualifier(EvqParamIn);
222 
223                     TVariable *param = new TVariable(symbolTable, kEmptyImmutableString, indexType,
224                                                      SymbolType::AngleInternal);
225                     substituteFunction->addParameter(param);
226 
227                     // The argument now uses the function parameters as indices.
228                     asBinary->replaceChildNode(asBinary->getRight(), new TIntermSymbol(param));
229                 }
230                 substituteArgument = asBinary->getLeft();
231             }
232 
233             ++nextReplacedArg;
234         }
235     }
236 
237     return substituteFunction;
238 }
239 
240 class MonomorphizeTraverser final : public TIntermTraverser
241 {
242   public:
MonomorphizeTraverser(TCompiler * compiler,TSymbolTable * symbolTable,UnsupportedFunctionArgsBitSet unsupportedFunctionArgs,FunctionMap * functionMap)243     explicit MonomorphizeTraverser(TCompiler *compiler,
244                                    TSymbolTable *symbolTable,
245                                    UnsupportedFunctionArgsBitSet unsupportedFunctionArgs,
246                                    FunctionMap *functionMap)
247         : TIntermTraverser(true, false, false, symbolTable),
248           mCompiler(compiler),
249           mUnsupportedFunctionArgs(unsupportedFunctionArgs),
250           mFunctionMap(functionMap)
251     {}
252 
visitAggregate(Visit visit,TIntermAggregate * node)253     bool visitAggregate(Visit visit, TIntermAggregate *node) override
254     {
255         if (node->getOp() != EOpCallFunctionInAST)
256         {
257             return true;
258         }
259 
260         const TFunction *function = node->getFunction();
261         ASSERT(function && mFunctionMap->find(function) != mFunctionMap->end());
262 
263         FunctionData &data = (*mFunctionMap)[function];
264 
265         TIntermFunctionDefinition *monomorphized =
266             processFunctionCall(node, data.originalDefinition, &data.isOriginalUsed);
267         if (monomorphized)
268         {
269             data.monomorphizedDefinitions.push_back(monomorphized);
270         }
271 
272         return true;
273     }
274 
getAnyMonomorphized() const275     bool getAnyMonomorphized() const { return mAnyMonomorphized; }
276 
277   private:
isUnsupportedArgument(TIntermTyped * callArgument,const TVariable * funcArgument) const278     bool isUnsupportedArgument(TIntermTyped *callArgument, const TVariable *funcArgument) const
279     {
280         // Only interested in opaque uniforms and structs that contain samplers.
281         const bool isOpaqueType = IsOpaqueType(funcArgument->getType().getBasicType());
282         const bool isStructContainingSamplers =
283             funcArgument->getType().isStructureContainingSamplers();
284         if (!isOpaqueType && !isStructContainingSamplers)
285         {
286             return false;
287         }
288 
289         // If not uniform (the variable was itself a function parameter), don't process it in
290         // this pass, as we don't know which actual uniform it corresponds to.
291         bool isSamplerInStruct   = false;
292         const TVariable *uniform = GetBaseUniform(callArgument, &isSamplerInStruct);
293         if (uniform == nullptr)
294         {
295             return false;
296         }
297 
298         const TType &type = uniform->getType();
299 
300         if (mUnsupportedFunctionArgs[UnsupportedFunctionArgs::StructContainingSamplers])
301         {
302             // Monomorphize if the parameter is a structure that contains samplers (so in
303             // RewriteStructSamplers we don't need to rewrite the functions to accept multiple
304             // parameters split from the struct).
305             if (isStructContainingSamplers)
306             {
307                 return true;
308             }
309         }
310 
311         if (mUnsupportedFunctionArgs[UnsupportedFunctionArgs::ArrayOfArrayOfSamplerOrImage])
312         {
313             // Monomorphize if:
314             //
315             // - The opaque uniform is a sampler in a struct (which can create an array-of-array
316             //   situation), and the function expects an array of samplers, or
317             //
318             // - The opaque uniform is an array of array of sampler or image, and it's partially
319             //   subscripted (i.e. the function itself expects an array)
320             //
321             const bool isParameterArrayOfOpaqueType = funcArgument->getType().isArray();
322             const bool isArrayOfArrayOfSamplerOrImage =
323                 (type.isSampler() || type.isImage()) && type.isArrayOfArrays();
324             if (isSamplerInStruct && isParameterArrayOfOpaqueType)
325             {
326                 return true;
327             }
328             if (isArrayOfArrayOfSamplerOrImage && isParameterArrayOfOpaqueType)
329             {
330                 return true;
331             }
332         }
333 
334         if (mUnsupportedFunctionArgs[UnsupportedFunctionArgs::AtomicCounter])
335         {
336             if (type.isAtomicCounter())
337             {
338                 return true;
339             }
340         }
341 
342         if (mUnsupportedFunctionArgs[UnsupportedFunctionArgs::Image])
343         {
344             if (type.isImage())
345             {
346                 return true;
347             }
348         }
349 
350         if (mUnsupportedFunctionArgs[UnsupportedFunctionArgs::PixelLocalStorage])
351         {
352             if (type.isPixelLocal())
353             {
354                 return true;
355             }
356         }
357 
358         return false;
359     }
360 
processFunctionCall(TIntermAggregate * functionCall,TIntermFunctionDefinition * originalDefinition,bool * isOriginalUsedOut)361     TIntermFunctionDefinition *processFunctionCall(TIntermAggregate *functionCall,
362                                                    TIntermFunctionDefinition *originalDefinition,
363                                                    bool *isOriginalUsedOut)
364     {
365         const TFunction *function            = functionCall->getFunction();
366         const TIntermSequence &callArguments = *functionCall->getSequence();
367 
368         TVector<Argument> replacedArguments;
369         TIntermSequence replacementIndices;
370 
371         // Go through function call arguments, and see if any is used in an unsupported way.
372         for (size_t argIndex = 0; argIndex < callArguments.size(); ++argIndex)
373         {
374             TIntermTyped *callArgument    = callArguments[argIndex]->getAsTyped();
375             const TVariable *funcArgument = function->getParam(argIndex);
376             if (isUnsupportedArgument(callArgument, funcArgument))
377             {
378                 // Copy the argument and extract the side effects.
379                 TIntermTyped *argument =
380                     ExtractSideEffects(mSymbolTable, callArgument, &replacementIndices);
381 
382                 replacedArguments.push_back({argIndex, argument});
383             }
384         }
385 
386         if (replacedArguments.empty())
387         {
388             *isOriginalUsedOut = true;
389             return nullptr;
390         }
391 
392         mAnyMonomorphized = true;
393 
394         insertStatementsInParentBlock(replacementIndices);
395 
396         // Create the arguments for the substitute function call.  Done before monomorphizing the
397         // function, which transforms the arguments to what needs to be replaced in the function
398         // body.
399         TIntermSequence newCallArgs;
400         CreateMonomorphizedFunctionCallArgs(callArguments, replacedArguments, &newCallArgs);
401 
402         // Duplicate the function and substitute the replaced arguments with only the non-const
403         // indices.  Additionally, substitute the non-const indices of arguments with the new
404         // function parameters.
405         VariableReplacementMap argumentMap;
406         const TFunction *monomorphized =
407             MonomorphizeFunction(mSymbolTable, function, &replacedArguments, &argumentMap);
408 
409         // Replace this function call with a call to the new one.
410         queueReplacement(TIntermAggregate::CreateFunctionCall(*monomorphized, &newCallArgs),
411                          OriginalNode::IS_DROPPED);
412 
413         // Create a new function definition, with the body of the old function but with the replaced
414         // parameters substituted with the calling expressions.
415         TIntermFunctionPrototype *substitutePrototype = new TIntermFunctionPrototype(monomorphized);
416         TIntermBlock *substituteBlock                 = originalDefinition->getBody()->deepCopy();
417         GetDeclaratorReplacements(mSymbolTable, substituteBlock, &argumentMap);
418         bool valid = ReplaceVariables(mCompiler, substituteBlock, argumentMap);
419         ASSERT(valid);
420 
421         return new TIntermFunctionDefinition(substitutePrototype, substituteBlock);
422     }
423 
424     TCompiler *mCompiler;
425     UnsupportedFunctionArgsBitSet mUnsupportedFunctionArgs;
426     bool mAnyMonomorphized = false;
427 
428     // Map of original to monomorphized functions.
429     FunctionMap *mFunctionMap;
430 };
431 
432 class UpdateFunctionsDefinitionsTraverser final : public TIntermTraverser
433 {
434   public:
UpdateFunctionsDefinitionsTraverser(TSymbolTable * symbolTable,const FunctionMap & functionMap)435     explicit UpdateFunctionsDefinitionsTraverser(TSymbolTable *symbolTable,
436                                                  const FunctionMap &functionMap)
437         : TIntermTraverser(true, false, false, symbolTable), mFunctionMap(functionMap)
438     {}
439 
visitFunctionPrototype(TIntermFunctionPrototype * node)440     void visitFunctionPrototype(TIntermFunctionPrototype *node) override
441     {
442         const bool isInFunctionDefinition = getParentNode()->getAsFunctionDefinition() != nullptr;
443         if (isInFunctionDefinition)
444         {
445             return;
446         }
447 
448         // Add to and possibly replace the function prototype with replacement prototypes.
449         const TFunction *function = node->getFunction();
450         ASSERT(function && mFunctionMap.find(function) != mFunctionMap.end());
451 
452         const FunctionData &data = mFunctionMap.at(function);
453 
454         // If nothing to do, leave it be.
455         if (data.monomorphizedDefinitions.empty())
456         {
457             ASSERT(data.isOriginalUsed || function->isMain());
458             return;
459         }
460 
461         // Replace the prototype with itself (if function is still used) as well as any
462         // monomorphized versions.
463         TIntermSequence replacement;
464         if (data.isOriginalUsed)
465         {
466             replacement.push_back(node);
467         }
468         for (TIntermFunctionDefinition *monomorphizedDefinition : data.monomorphizedDefinitions)
469         {
470             replacement.push_back(new TIntermFunctionPrototype(
471                 monomorphizedDefinition->getFunctionPrototype()->getFunction()));
472         }
473         mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node,
474                                         std::move(replacement));
475     }
476 
visitFunctionDefinition(Visit visit,TIntermFunctionDefinition * node)477     bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override
478     {
479         // Add to and possibly replace the function definition with replacement definitions.
480         const TFunction *function = node->getFunction();
481         ASSERT(function && mFunctionMap.find(function) != mFunctionMap.end());
482 
483         const FunctionData &data = mFunctionMap.at(function);
484 
485         // If nothing to do, leave it be.
486         if (data.monomorphizedDefinitions.empty())
487         {
488             ASSERT(data.isOriginalUsed || function->isMain());
489             return false;
490         }
491 
492         // Replace the definition with itself (if function is still used) as well as any
493         // monomorphized versions.
494         TIntermSequence replacement;
495         if (data.isOriginalUsed)
496         {
497             replacement.push_back(node);
498         }
499         for (TIntermFunctionDefinition *monomorphizedDefinition : data.monomorphizedDefinitions)
500         {
501             replacement.push_back(monomorphizedDefinition);
502         }
503         mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node,
504                                         std::move(replacement));
505 
506         return false;
507     }
508 
509   private:
510     const FunctionMap &mFunctionMap;
511 };
512 
SortDeclarations(TIntermBlock * root)513 void SortDeclarations(TIntermBlock *root)
514 {
515     TIntermSequence *original = root->getSequence();
516 
517     TIntermSequence replacement;
518     TIntermSequence functionDefs;
519 
520     // Accumulate non-function-definition declarations in |replacement| and function definitions in
521     // |functionDefs|.
522     for (TIntermNode *node : *original)
523     {
524         if (node->getAsFunctionDefinition() || node->getAsFunctionPrototypeNode())
525         {
526             functionDefs.push_back(node);
527         }
528         else
529         {
530             replacement.push_back(node);
531         }
532     }
533 
534     // Append function definitions to |replacement|.
535     replacement.insert(replacement.end(), functionDefs.begin(), functionDefs.end());
536 
537     // Replace root's sequence with |replacement|.
538     root->replaceAllChildren(replacement);
539 }
540 
MonomorphizeUnsupportedFunctionsImpl(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,UnsupportedFunctionArgsBitSet unsupportedFunctionArgs)541 bool MonomorphizeUnsupportedFunctionsImpl(TCompiler *compiler,
542                                           TIntermBlock *root,
543                                           TSymbolTable *symbolTable,
544                                           UnsupportedFunctionArgsBitSet unsupportedFunctionArgs)
545 {
546     // First, sort out the declarations such that all non-function declarations are placed before
547     // function definitions.  This way when the function is replaced with one that references said
548     // declarations (i.e. uniforms), the uniform declaration is already present above it.
549     SortDeclarations(root);
550 
551     while (true)
552     {
553         FunctionMap functionMap;
554         InitializeFunctionMap(root, &functionMap);
555 
556         MonomorphizeTraverser monomorphizer(compiler, symbolTable, unsupportedFunctionArgs,
557                                             &functionMap);
558         root->traverse(&monomorphizer);
559 
560         if (!monomorphizer.getAnyMonomorphized())
561         {
562             break;
563         }
564 
565         if (!monomorphizer.updateTree(compiler, root))
566         {
567             return false;
568         }
569 
570         UpdateFunctionsDefinitionsTraverser functionUpdater(symbolTable, functionMap);
571         root->traverse(&functionUpdater);
572 
573         if (!functionUpdater.updateTree(compiler, root))
574         {
575             return false;
576         }
577     }
578 
579     return true;
580 }
581 }  // anonymous namespace
582 
MonomorphizeUnsupportedFunctions(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,UnsupportedFunctionArgsBitSet unsupportedFunctionArgs)583 bool MonomorphizeUnsupportedFunctions(TCompiler *compiler,
584                                       TIntermBlock *root,
585                                       TSymbolTable *symbolTable,
586                                       UnsupportedFunctionArgsBitSet unsupportedFunctionArgs)
587 {
588     // This function actually applies multiple transformation, and the AST may not be valid until
589     // the transformations are entirely done.  Some validation is momentarily disabled.
590     bool enableValidateFunctionCall = compiler->disableValidateFunctionCall();
591 
592     bool result =
593         MonomorphizeUnsupportedFunctionsImpl(compiler, root, symbolTable, unsupportedFunctionArgs);
594 
595     compiler->restoreValidateFunctionCall(enableValidateFunctionCall);
596     return result && compiler->validateAST(root);
597 }
598 }  // namespace sh
599