1 //
2 // Copyright 2021 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // MonomorphizeUnsupportedFunctions: Monomorphize functions that are called with
7 // parameters that are incompatible with both Vulkan GLSL and Metal.
8 //
9
10 #include "compiler/translator/tree_ops/MonomorphizeUnsupportedFunctions.h"
11
12 #include "compiler/translator/ImmutableStringBuilder.h"
13 #include "compiler/translator/SymbolTable.h"
14 #include "compiler/translator/tree_util/IntermNode_util.h"
15 #include "compiler/translator/tree_util/IntermTraverse.h"
16 #include "compiler/translator/tree_util/ReplaceVariable.h"
17
18 namespace sh
19 {
20 namespace
21 {
22 struct Argument
23 {
24 size_t argumentIndex;
25 TIntermTyped *argument;
26 };
27
28 struct FunctionData
29 {
30 // Whether the original function is used. If this is false, the function can be removed because
31 // all callers have been modified.
32 bool isOriginalUsed;
33 // The original definition of the function, used to create the monomorphized version.
34 TIntermFunctionDefinition *originalDefinition;
35 // List of monomorphized versions of this function. They will be added next to the original
36 // version (or replace it).
37 TVector<TIntermFunctionDefinition *> monomorphizedDefinitions;
38 };
39
40 using FunctionMap = angle::HashMap<const TFunction *, FunctionData>;
41
42 // Traverse the function definitions and initialize the map. Allows visitAggregate to have access
43 // to TIntermFunctionDefinition even when the function is only forward declared at that point.
InitializeFunctionMap(TIntermBlock * root,FunctionMap * functionMapOut)44 void InitializeFunctionMap(TIntermBlock *root, FunctionMap *functionMapOut)
45 {
46 TIntermSequence &sequence = *root->getSequence();
47
48 for (TIntermNode *node : sequence)
49 {
50 TIntermFunctionDefinition *asFuncDef = node->getAsFunctionDefinition();
51 if (asFuncDef != nullptr)
52 {
53 const TFunction *function = asFuncDef->getFunction();
54 ASSERT(function && functionMapOut->find(function) == functionMapOut->end());
55 (*functionMapOut)[function] = FunctionData{false, asFuncDef, {}};
56 }
57 }
58 }
59
GetBaseUniform(TIntermTyped * node,bool * isSamplerInStructOut)60 const TVariable *GetBaseUniform(TIntermTyped *node, bool *isSamplerInStructOut)
61 {
62 *isSamplerInStructOut = false;
63
64 while (node->getAsBinaryNode())
65 {
66 TIntermBinary *asBinary = node->getAsBinaryNode();
67
68 TOperator op = asBinary->getOp();
69
70 // No opaque uniform can be inside an interface block.
71 if (op == EOpIndexDirectInterfaceBlock)
72 {
73 return nullptr;
74 }
75
76 if (op == EOpIndexDirectStruct)
77 {
78 *isSamplerInStructOut = true;
79 }
80
81 node = asBinary->getLeft();
82 }
83
84 // Only interested in uniform opaque types. If a function call within another function uses
85 // opaque uniforms in an unsupported way, it will be replaced in a follow up pass after the
86 // calling function is monomorphized.
87 if (node->getType().getQualifier() != EvqUniform)
88 {
89 return nullptr;
90 }
91
92 ASSERT(IsOpaqueType(node->getType().getBasicType()) ||
93 node->getType().isStructureContainingSamplers());
94
95 TIntermSymbol *asSymbol = node->getAsSymbolNode();
96 ASSERT(asSymbol);
97
98 return &asSymbol->variable();
99 }
100
ExtractSideEffects(TSymbolTable * symbolTable,TIntermTyped * node,TIntermSequence * replacementIndices)101 TIntermTyped *ExtractSideEffects(TSymbolTable *symbolTable,
102 TIntermTyped *node,
103 TIntermSequence *replacementIndices)
104 {
105 TIntermTyped *withoutSideEffects = node->deepCopy();
106
107 for (TIntermBinary *asBinary = withoutSideEffects->getAsBinaryNode(); asBinary;
108 asBinary = asBinary->getLeft()->getAsBinaryNode())
109 {
110 TOperator op = asBinary->getOp();
111 TIntermTyped *index = asBinary->getRight();
112
113 if (op == EOpIndexDirectStruct)
114 {
115 break;
116 }
117
118 // No side effects with constant expressions.
119 if (op == EOpIndexDirect)
120 {
121 ASSERT(index->getAsConstantUnion());
122 continue;
123 }
124
125 ASSERT(op == EOpIndexIndirect);
126
127 // If the index is a symbol, there's no side effect, so leave it as-is.
128 if (index->getAsSymbolNode())
129 {
130 continue;
131 }
132
133 // Otherwise create a temp variable initialized with the index and use that temp variable as
134 // the index.
135 TIntermDeclaration *tempDecl = nullptr;
136 TVariable *tempVar = DeclareTempVariable(symbolTable, index, EvqTemporary, &tempDecl);
137
138 replacementIndices->push_back(tempDecl);
139 asBinary->replaceChildNode(index, new TIntermSymbol(tempVar));
140 }
141
142 return withoutSideEffects;
143 }
144
CreateMonomorphizedFunctionCallArgs(const TIntermSequence & originalCallArguments,const TVector<Argument> & replacedArguments,TIntermSequence * substituteArgsOut)145 void CreateMonomorphizedFunctionCallArgs(const TIntermSequence &originalCallArguments,
146 const TVector<Argument> &replacedArguments,
147 TIntermSequence *substituteArgsOut)
148 {
149 size_t nextReplacedArg = 0;
150 for (size_t argIndex = 0; argIndex < originalCallArguments.size(); ++argIndex)
151 {
152 if (nextReplacedArg >= replacedArguments.size() ||
153 argIndex != replacedArguments[nextReplacedArg].argumentIndex)
154 {
155 // Not replaced, keep argument as is.
156 substituteArgsOut->push_back(originalCallArguments[argIndex]);
157 }
158 else
159 {
160 TIntermTyped *argument = replacedArguments[nextReplacedArg].argument;
161
162 // Iterate over indices of the argument and create a new arg for every non-const
163 // index. Note that the index itself may be an expression, and it may require further
164 // substitution in the next pass.
165 while (argument->getAsBinaryNode())
166 {
167 TIntermBinary *asBinary = argument->getAsBinaryNode();
168 if (asBinary->getOp() == EOpIndexIndirect)
169 {
170 TIntermTyped *index = asBinary->getRight();
171 substituteArgsOut->push_back(index->deepCopy());
172 }
173 argument = asBinary->getLeft();
174 }
175
176 ++nextReplacedArg;
177 }
178 }
179 }
180
MonomorphizeFunction(TSymbolTable * symbolTable,const TFunction * original,TVector<Argument> * replacedArguments,VariableReplacementMap * argumentMapOut)181 const TFunction *MonomorphizeFunction(TSymbolTable *symbolTable,
182 const TFunction *original,
183 TVector<Argument> *replacedArguments,
184 VariableReplacementMap *argumentMapOut)
185 {
186 TFunction *substituteFunction =
187 new TFunction(symbolTable, kEmptyImmutableString, SymbolType::AngleInternal,
188 &original->getReturnType(), original->isKnownToNotHaveSideEffects());
189
190 size_t nextReplacedArg = 0;
191 for (size_t paramIndex = 0; paramIndex < original->getParamCount(); ++paramIndex)
192 {
193 const TVariable *originalParam = original->getParam(paramIndex);
194
195 if (nextReplacedArg >= replacedArguments->size() ||
196 paramIndex != (*replacedArguments)[nextReplacedArg].argumentIndex)
197 {
198 TVariable *substituteArgument =
199 new TVariable(symbolTable, originalParam->name(), &originalParam->getType(),
200 originalParam->symbolType());
201 // Not replaced, add an identical parameter.
202 substituteFunction->addParameter(substituteArgument);
203 (*argumentMapOut)[originalParam] = new TIntermSymbol(substituteArgument);
204 }
205 else
206 {
207 TIntermTyped *substituteArgument = (*replacedArguments)[nextReplacedArg].argument;
208 (*argumentMapOut)[originalParam] = substituteArgument;
209
210 // Iterate over indices of the argument and create a new parameter for every non-const
211 // index (which may be an expression). Replace the symbol in the argument with a
212 // variable of the index type. This is later used to replace the parameter in the
213 // function body.
214 while (substituteArgument->getAsBinaryNode())
215 {
216 TIntermBinary *asBinary = substituteArgument->getAsBinaryNode();
217 if (asBinary->getOp() == EOpIndexIndirect)
218 {
219 TIntermTyped *index = asBinary->getRight();
220 TType *indexType = new TType(index->getType());
221 indexType->setQualifier(EvqParamIn);
222
223 TVariable *param = new TVariable(symbolTable, kEmptyImmutableString, indexType,
224 SymbolType::AngleInternal);
225 substituteFunction->addParameter(param);
226
227 // The argument now uses the function parameters as indices.
228 asBinary->replaceChildNode(asBinary->getRight(), new TIntermSymbol(param));
229 }
230 substituteArgument = asBinary->getLeft();
231 }
232
233 ++nextReplacedArg;
234 }
235 }
236
237 return substituteFunction;
238 }
239
240 class MonomorphizeTraverser final : public TIntermTraverser
241 {
242 public:
MonomorphizeTraverser(TCompiler * compiler,TSymbolTable * symbolTable,UnsupportedFunctionArgsBitSet unsupportedFunctionArgs,FunctionMap * functionMap)243 explicit MonomorphizeTraverser(TCompiler *compiler,
244 TSymbolTable *symbolTable,
245 UnsupportedFunctionArgsBitSet unsupportedFunctionArgs,
246 FunctionMap *functionMap)
247 : TIntermTraverser(true, false, false, symbolTable),
248 mCompiler(compiler),
249 mUnsupportedFunctionArgs(unsupportedFunctionArgs),
250 mFunctionMap(functionMap)
251 {}
252
visitAggregate(Visit visit,TIntermAggregate * node)253 bool visitAggregate(Visit visit, TIntermAggregate *node) override
254 {
255 if (node->getOp() != EOpCallFunctionInAST)
256 {
257 return true;
258 }
259
260 const TFunction *function = node->getFunction();
261 ASSERT(function && mFunctionMap->find(function) != mFunctionMap->end());
262
263 FunctionData &data = (*mFunctionMap)[function];
264
265 TIntermFunctionDefinition *monomorphized =
266 processFunctionCall(node, data.originalDefinition, &data.isOriginalUsed);
267 if (monomorphized)
268 {
269 data.monomorphizedDefinitions.push_back(monomorphized);
270 }
271
272 return true;
273 }
274
getAnyMonomorphized() const275 bool getAnyMonomorphized() const { return mAnyMonomorphized; }
276
277 private:
isUnsupportedArgument(TIntermTyped * callArgument,const TVariable * funcArgument) const278 bool isUnsupportedArgument(TIntermTyped *callArgument, const TVariable *funcArgument) const
279 {
280 // Only interested in opaque uniforms and structs that contain samplers.
281 const bool isOpaqueType = IsOpaqueType(funcArgument->getType().getBasicType());
282 const bool isStructContainingSamplers =
283 funcArgument->getType().isStructureContainingSamplers();
284 if (!isOpaqueType && !isStructContainingSamplers)
285 {
286 return false;
287 }
288
289 // If not uniform (the variable was itself a function parameter), don't process it in
290 // this pass, as we don't know which actual uniform it corresponds to.
291 bool isSamplerInStruct = false;
292 const TVariable *uniform = GetBaseUniform(callArgument, &isSamplerInStruct);
293 if (uniform == nullptr)
294 {
295 return false;
296 }
297
298 const TType &type = uniform->getType();
299
300 if (mUnsupportedFunctionArgs[UnsupportedFunctionArgs::StructContainingSamplers])
301 {
302 // Monomorphize if the parameter is a structure that contains samplers (so in
303 // RewriteStructSamplers we don't need to rewrite the functions to accept multiple
304 // parameters split from the struct).
305 if (isStructContainingSamplers)
306 {
307 return true;
308 }
309 }
310
311 if (mUnsupportedFunctionArgs[UnsupportedFunctionArgs::ArrayOfArrayOfSamplerOrImage])
312 {
313 // Monomorphize if:
314 //
315 // - The opaque uniform is a sampler in a struct (which can create an array-of-array
316 // situation), and the function expects an array of samplers, or
317 //
318 // - The opaque uniform is an array of array of sampler or image, and it's partially
319 // subscripted (i.e. the function itself expects an array)
320 //
321 const bool isParameterArrayOfOpaqueType = funcArgument->getType().isArray();
322 const bool isArrayOfArrayOfSamplerOrImage =
323 (type.isSampler() || type.isImage()) && type.isArrayOfArrays();
324 if (isSamplerInStruct && isParameterArrayOfOpaqueType)
325 {
326 return true;
327 }
328 if (isArrayOfArrayOfSamplerOrImage && isParameterArrayOfOpaqueType)
329 {
330 return true;
331 }
332 }
333
334 if (mUnsupportedFunctionArgs[UnsupportedFunctionArgs::AtomicCounter])
335 {
336 if (type.isAtomicCounter())
337 {
338 return true;
339 }
340 }
341
342 if (mUnsupportedFunctionArgs[UnsupportedFunctionArgs::Image])
343 {
344 if (type.isImage())
345 {
346 return true;
347 }
348 }
349
350 if (mUnsupportedFunctionArgs[UnsupportedFunctionArgs::PixelLocalStorage])
351 {
352 if (type.isPixelLocal())
353 {
354 return true;
355 }
356 }
357
358 return false;
359 }
360
processFunctionCall(TIntermAggregate * functionCall,TIntermFunctionDefinition * originalDefinition,bool * isOriginalUsedOut)361 TIntermFunctionDefinition *processFunctionCall(TIntermAggregate *functionCall,
362 TIntermFunctionDefinition *originalDefinition,
363 bool *isOriginalUsedOut)
364 {
365 const TFunction *function = functionCall->getFunction();
366 const TIntermSequence &callArguments = *functionCall->getSequence();
367
368 TVector<Argument> replacedArguments;
369 TIntermSequence replacementIndices;
370
371 // Go through function call arguments, and see if any is used in an unsupported way.
372 for (size_t argIndex = 0; argIndex < callArguments.size(); ++argIndex)
373 {
374 TIntermTyped *callArgument = callArguments[argIndex]->getAsTyped();
375 const TVariable *funcArgument = function->getParam(argIndex);
376 if (isUnsupportedArgument(callArgument, funcArgument))
377 {
378 // Copy the argument and extract the side effects.
379 TIntermTyped *argument =
380 ExtractSideEffects(mSymbolTable, callArgument, &replacementIndices);
381
382 replacedArguments.push_back({argIndex, argument});
383 }
384 }
385
386 if (replacedArguments.empty())
387 {
388 *isOriginalUsedOut = true;
389 return nullptr;
390 }
391
392 mAnyMonomorphized = true;
393
394 insertStatementsInParentBlock(replacementIndices);
395
396 // Create the arguments for the substitute function call. Done before monomorphizing the
397 // function, which transforms the arguments to what needs to be replaced in the function
398 // body.
399 TIntermSequence newCallArgs;
400 CreateMonomorphizedFunctionCallArgs(callArguments, replacedArguments, &newCallArgs);
401
402 // Duplicate the function and substitute the replaced arguments with only the non-const
403 // indices. Additionally, substitute the non-const indices of arguments with the new
404 // function parameters.
405 VariableReplacementMap argumentMap;
406 const TFunction *monomorphized =
407 MonomorphizeFunction(mSymbolTable, function, &replacedArguments, &argumentMap);
408
409 // Replace this function call with a call to the new one.
410 queueReplacement(TIntermAggregate::CreateFunctionCall(*monomorphized, &newCallArgs),
411 OriginalNode::IS_DROPPED);
412
413 // Create a new function definition, with the body of the old function but with the replaced
414 // parameters substituted with the calling expressions.
415 TIntermFunctionPrototype *substitutePrototype = new TIntermFunctionPrototype(monomorphized);
416 TIntermBlock *substituteBlock = originalDefinition->getBody()->deepCopy();
417 GetDeclaratorReplacements(mSymbolTable, substituteBlock, &argumentMap);
418 bool valid = ReplaceVariables(mCompiler, substituteBlock, argumentMap);
419 ASSERT(valid);
420
421 return new TIntermFunctionDefinition(substitutePrototype, substituteBlock);
422 }
423
424 TCompiler *mCompiler;
425 UnsupportedFunctionArgsBitSet mUnsupportedFunctionArgs;
426 bool mAnyMonomorphized = false;
427
428 // Map of original to monomorphized functions.
429 FunctionMap *mFunctionMap;
430 };
431
432 class UpdateFunctionsDefinitionsTraverser final : public TIntermTraverser
433 {
434 public:
UpdateFunctionsDefinitionsTraverser(TSymbolTable * symbolTable,const FunctionMap & functionMap)435 explicit UpdateFunctionsDefinitionsTraverser(TSymbolTable *symbolTable,
436 const FunctionMap &functionMap)
437 : TIntermTraverser(true, false, false, symbolTable), mFunctionMap(functionMap)
438 {}
439
visitFunctionPrototype(TIntermFunctionPrototype * node)440 void visitFunctionPrototype(TIntermFunctionPrototype *node) override
441 {
442 const bool isInFunctionDefinition = getParentNode()->getAsFunctionDefinition() != nullptr;
443 if (isInFunctionDefinition)
444 {
445 return;
446 }
447
448 // Add to and possibly replace the function prototype with replacement prototypes.
449 const TFunction *function = node->getFunction();
450 ASSERT(function && mFunctionMap.find(function) != mFunctionMap.end());
451
452 const FunctionData &data = mFunctionMap.at(function);
453
454 // If nothing to do, leave it be.
455 if (data.monomorphizedDefinitions.empty())
456 {
457 ASSERT(data.isOriginalUsed || function->isMain());
458 return;
459 }
460
461 // Replace the prototype with itself (if function is still used) as well as any
462 // monomorphized versions.
463 TIntermSequence replacement;
464 if (data.isOriginalUsed)
465 {
466 replacement.push_back(node);
467 }
468 for (TIntermFunctionDefinition *monomorphizedDefinition : data.monomorphizedDefinitions)
469 {
470 replacement.push_back(new TIntermFunctionPrototype(
471 monomorphizedDefinition->getFunctionPrototype()->getFunction()));
472 }
473 mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node,
474 std::move(replacement));
475 }
476
visitFunctionDefinition(Visit visit,TIntermFunctionDefinition * node)477 bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override
478 {
479 // Add to and possibly replace the function definition with replacement definitions.
480 const TFunction *function = node->getFunction();
481 ASSERT(function && mFunctionMap.find(function) != mFunctionMap.end());
482
483 const FunctionData &data = mFunctionMap.at(function);
484
485 // If nothing to do, leave it be.
486 if (data.monomorphizedDefinitions.empty())
487 {
488 ASSERT(data.isOriginalUsed || function->isMain());
489 return false;
490 }
491
492 // Replace the definition with itself (if function is still used) as well as any
493 // monomorphized versions.
494 TIntermSequence replacement;
495 if (data.isOriginalUsed)
496 {
497 replacement.push_back(node);
498 }
499 for (TIntermFunctionDefinition *monomorphizedDefinition : data.monomorphizedDefinitions)
500 {
501 replacement.push_back(monomorphizedDefinition);
502 }
503 mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node,
504 std::move(replacement));
505
506 return false;
507 }
508
509 private:
510 const FunctionMap &mFunctionMap;
511 };
512
SortDeclarations(TIntermBlock * root)513 void SortDeclarations(TIntermBlock *root)
514 {
515 TIntermSequence *original = root->getSequence();
516
517 TIntermSequence replacement;
518 TIntermSequence functionDefs;
519
520 // Accumulate non-function-definition declarations in |replacement| and function definitions in
521 // |functionDefs|.
522 for (TIntermNode *node : *original)
523 {
524 if (node->getAsFunctionDefinition() || node->getAsFunctionPrototypeNode())
525 {
526 functionDefs.push_back(node);
527 }
528 else
529 {
530 replacement.push_back(node);
531 }
532 }
533
534 // Append function definitions to |replacement|.
535 replacement.insert(replacement.end(), functionDefs.begin(), functionDefs.end());
536
537 // Replace root's sequence with |replacement|.
538 root->replaceAllChildren(replacement);
539 }
540
MonomorphizeUnsupportedFunctionsImpl(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,UnsupportedFunctionArgsBitSet unsupportedFunctionArgs)541 bool MonomorphizeUnsupportedFunctionsImpl(TCompiler *compiler,
542 TIntermBlock *root,
543 TSymbolTable *symbolTable,
544 UnsupportedFunctionArgsBitSet unsupportedFunctionArgs)
545 {
546 // First, sort out the declarations such that all non-function declarations are placed before
547 // function definitions. This way when the function is replaced with one that references said
548 // declarations (i.e. uniforms), the uniform declaration is already present above it.
549 SortDeclarations(root);
550
551 while (true)
552 {
553 FunctionMap functionMap;
554 InitializeFunctionMap(root, &functionMap);
555
556 MonomorphizeTraverser monomorphizer(compiler, symbolTable, unsupportedFunctionArgs,
557 &functionMap);
558 root->traverse(&monomorphizer);
559
560 if (!monomorphizer.getAnyMonomorphized())
561 {
562 break;
563 }
564
565 if (!monomorphizer.updateTree(compiler, root))
566 {
567 return false;
568 }
569
570 UpdateFunctionsDefinitionsTraverser functionUpdater(symbolTable, functionMap);
571 root->traverse(&functionUpdater);
572
573 if (!functionUpdater.updateTree(compiler, root))
574 {
575 return false;
576 }
577 }
578
579 return true;
580 }
581 } // anonymous namespace
582
MonomorphizeUnsupportedFunctions(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,UnsupportedFunctionArgsBitSet unsupportedFunctionArgs)583 bool MonomorphizeUnsupportedFunctions(TCompiler *compiler,
584 TIntermBlock *root,
585 TSymbolTable *symbolTable,
586 UnsupportedFunctionArgsBitSet unsupportedFunctionArgs)
587 {
588 // This function actually applies multiple transformation, and the AST may not be valid until
589 // the transformations are entirely done. Some validation is momentarily disabled.
590 bool enableValidateFunctionCall = compiler->disableValidateFunctionCall();
591
592 bool result =
593 MonomorphizeUnsupportedFunctionsImpl(compiler, root, symbolTable, unsupportedFunctionArgs);
594
595 compiler->restoreValidateFunctionCall(enableValidateFunctionCall);
596 return result && compiler->validateAST(root);
597 }
598 } // namespace sh
599