xref: /aosp_15_r20/external/angle/src/compiler/translator/msl/RewritePipelines.cpp (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include <cstring>
8 #include <unordered_map>
9 #include <unordered_set>
10 
11 #include "compiler/translator/IntermRebuild.h"
12 #include "compiler/translator/msl/AstHelpers.h"
13 #include "compiler/translator/msl/DiscoverDependentFunctions.h"
14 #include "compiler/translator/msl/IdGen.h"
15 #include "compiler/translator/msl/MapSymbols.h"
16 #include "compiler/translator/msl/Pipeline.h"
17 #include "compiler/translator/msl/RewritePipelines.h"
18 #include "compiler/translator/msl/SymbolEnv.h"
19 #include "compiler/translator/msl/TranslatorMSL.h"
20 #include "compiler/translator/tree_ops/PruneNoOps.h"
21 #include "compiler/translator/tree_util/DriverUniform.h"
22 #include "compiler/translator/tree_util/FindMain.h"
23 #include "compiler/translator/tree_util/IntermTraverse.h"
24 #include "compiler/translator/util.h"
25 
26 using namespace sh;
27 
28 ////////////////////////////////////////////////////////////////////////////////
29 
30 namespace
31 {
32 
IsVariableInvariant(const std::vector<sh::ShaderVariable> & mVars,const ImmutableString & name)33 bool IsVariableInvariant(const std::vector<sh::ShaderVariable> &mVars, const ImmutableString &name)
34 {
35     for (const auto &var : mVars)
36     {
37         if (name == var.name)
38         {
39             return var.isInvariant;
40         }
41     }
42     // TODO(kpidington): this should be UNREACHABLE() but isn't because the translator generates
43     // declarations to unused built-in variables.
44     return false;
45 }
46 
47 using VariableSet  = std::unordered_set<const TVariable *>;
48 using VariableList = std::vector<const TVariable *>;
49 
50 ////////////////////////////////////////////////////////////////////////////////
51 
52 struct PipelineStructInfo
53 {
54     VariableSet pipelineVariables;
55     PipelineScoped<TStructure> pipelineStruct;
56     const TFunction *funcOriginalToModified = nullptr;
57     const TFunction *funcModifiedToOriginal = nullptr;
58 
isEmpty__anon7ec8f16e0111::PipelineStructInfo59     bool isEmpty() const
60     {
61         if (pipelineStruct.isTotallyEmpty())
62         {
63             ASSERT(pipelineVariables.empty());
64             return true;
65         }
66         else
67         {
68             ASSERT(pipelineStruct.isTotallyFull());
69             ASSERT(!pipelineVariables.empty());
70             return false;
71         }
72     }
73 };
74 
75 class GeneratePipelineStruct : private TIntermRebuild
76 {
77   private:
78     const Pipeline &mPipeline;
79     SymbolEnv &mSymbolEnv;
80     const std::vector<sh::ShaderVariable> *mVariableInfos;
81     VariableList mPipelineVariableList;
82     IdGen &mIdGen;
83     PipelineStructInfo mInfo;
84 
85   public:
Exec(PipelineStructInfo & out,TCompiler & compiler,TIntermBlock & root,IdGen & idGen,const Pipeline & pipeline,SymbolEnv & symbolEnv,const std::vector<sh::ShaderVariable> * variableInfos)86     static bool Exec(PipelineStructInfo &out,
87                      TCompiler &compiler,
88                      TIntermBlock &root,
89                      IdGen &idGen,
90                      const Pipeline &pipeline,
91                      SymbolEnv &symbolEnv,
92                      const std::vector<sh::ShaderVariable> *variableInfos)
93     {
94         GeneratePipelineStruct self(compiler, idGen, pipeline, symbolEnv, variableInfos);
95         if (!self.exec(root))
96         {
97             return false;
98         }
99         out = self.mInfo;
100         return true;
101     }
102 
103   private:
GeneratePipelineStruct(TCompiler & compiler,IdGen & idGen,const Pipeline & pipeline,SymbolEnv & symbolEnv,const std::vector<sh::ShaderVariable> * variableInfos)104     GeneratePipelineStruct(TCompiler &compiler,
105                            IdGen &idGen,
106                            const Pipeline &pipeline,
107                            SymbolEnv &symbolEnv,
108                            const std::vector<sh::ShaderVariable> *variableInfos)
109         : TIntermRebuild(compiler, true, true),
110           mPipeline(pipeline),
111           mSymbolEnv(symbolEnv),
112           mVariableInfos(variableInfos),
113           mIdGen(idGen)
114     {}
115 
exec(TIntermBlock & root)116     bool exec(TIntermBlock &root)
117     {
118         if (!rebuildRoot(root))
119         {
120             return false;
121         }
122 
123         if (mInfo.pipelineVariables.empty())
124         {
125             return true;
126         }
127 
128         TIntermSequence seq;
129 
130         const TStructure &pipelineStruct = [&]() -> const TStructure & {
131             if (mPipeline.globalInstanceVar)
132             {
133                 return *mPipeline.globalInstanceVar->getType().getStruct();
134             }
135             else
136             {
137                 return createInternalPipelineStruct(root, seq);
138             }
139         }();
140 
141         ModifiedStructMachineries modifiedMachineries;
142         const bool isUBO     = mPipeline.type == Pipeline::Type::UniformBuffer;
143         const bool isUniform = mPipeline.type == Pipeline::Type::UniformBuffer ||
144                                mPipeline.type == Pipeline::Type::UserUniforms;
145         const bool useAttributeAliasing =
146             mPipeline.type == Pipeline::Type::VertexIn && mCompiler.supportsAttributeAliasing();
147         const bool modified = TryCreateModifiedStruct(
148             mCompiler, mSymbolEnv, mIdGen, mPipeline.externalStructModifyConfig(), pipelineStruct,
149             mPipeline.getStructTypeName(Pipeline::Variant::Modified), modifiedMachineries, isUBO,
150             !isUniform, useAttributeAliasing);
151 
152         if (modified)
153         {
154             ASSERT(mPipeline.type != Pipeline::Type::Texture);
155             ASSERT(mPipeline.type == Pipeline::Type::AngleUniforms ||
156                    !mPipeline.globalInstanceVar);  // This shouldn't happen by construction.
157 
158             auto getFunction = [](sh::TIntermFunctionDefinition *funcDecl) {
159                 return funcDecl ? funcDecl->getFunction() : nullptr;
160             };
161 
162             const size_t size = modifiedMachineries.size();
163             ASSERT(size > 0);
164             for (size_t i = 0; i < size; ++i)
165             {
166                 const ModifiedStructMachinery &machinery = modifiedMachineries.at(i);
167                 ASSERT(machinery.modifiedStruct);
168 
169                 seq.push_back(new TIntermDeclaration{
170                     &CreateStructTypeVariable(mSymbolTable, *machinery.modifiedStruct)});
171 
172                 if (mPipeline.isPipelineOut())
173                 {
174                     ASSERT(machinery.funcOriginalToModified);
175                     ASSERT(!machinery.funcModifiedToOriginal);
176                     seq.push_back(machinery.funcOriginalToModified);
177                 }
178                 else
179                 {
180                     ASSERT(machinery.funcModifiedToOriginal);
181                     ASSERT(!machinery.funcOriginalToModified);
182                     seq.push_back(machinery.funcModifiedToOriginal);
183                 }
184 
185                 if (i == size - 1)
186                 {
187                     mInfo.funcOriginalToModified = getFunction(machinery.funcOriginalToModified);
188                     mInfo.funcModifiedToOriginal = getFunction(machinery.funcModifiedToOriginal);
189 
190                     mInfo.pipelineStruct.internal = &pipelineStruct;
191                     mInfo.pipelineStruct.external =
192                         modified ? machinery.modifiedStruct : &pipelineStruct;
193                 }
194             }
195         }
196         else
197         {
198             mInfo.pipelineStruct.internal = &pipelineStruct;
199             mInfo.pipelineStruct.external = &pipelineStruct;
200         }
201 
202         if (mPipeline.type == Pipeline::Type::FragmentOut &&
203             mCompiler.hasPixelLocalStorageUniforms() &&
204             mCompiler.getPixelLocalStorageType() == ShPixelLocalStorageType::FramebufferFetch)
205         {
206             auto &fields = *new TFieldList();
207             for (const TField *field : mInfo.pipelineStruct.external->fields())
208             {
209                 if (field->type()->getQualifier() == EvqFragmentInOut)
210                 {
211                     fields.push_back(new TField(&CloneType(*field->type()), field->name(),
212                                                 kNoSourceLoc, field->symbolType()));
213                 }
214             }
215             TStructure *extraStruct =
216                 new TStructure(&mSymbolTable, ImmutableString("LastFragmentOut"), &fields,
217                                SymbolType::AngleInternal);
218             seq.push_back(
219                 new TIntermDeclaration{&CreateStructTypeVariable(mSymbolTable, *extraStruct)});
220             mInfo.pipelineStruct.externalExtra = extraStruct;
221         }
222 
223         root.insertChildNodes(FindMainIndex(&root), seq);
224 
225         return true;
226     }
227 
228   private:
visitFunctionDefinitionPre(TIntermFunctionDefinition & node)229     PreResult visitFunctionDefinitionPre(TIntermFunctionDefinition &node) override
230     {
231         return {node, VisitBits::Neither};
232     }
visitDeclarationPost(TIntermDeclaration & declNode)233     PostResult visitDeclarationPost(TIntermDeclaration &declNode) override
234     {
235         Declaration decl     = ViewDeclaration(declNode);
236         const TVariable &var = decl.symbol.variable();
237         if (mPipeline.uses(var))
238         {
239             ASSERT(mInfo.pipelineVariables.find(&var) == mInfo.pipelineVariables.end());
240             mInfo.pipelineVariables.insert(&var);
241             mPipelineVariableList.push_back(&var);
242             return nullptr;
243         }
244 
245         return declNode;
246     }
247 
createInternalPipelineStruct(TIntermBlock & root,TIntermSequence & outDeclSeq)248     const TStructure &createInternalPipelineStruct(TIntermBlock &root, TIntermSequence &outDeclSeq)
249     {
250         auto &fields = *new TFieldList();
251 
252         switch (mPipeline.type)
253         {
254             case Pipeline::Type::Texture:
255             {
256                 for (const TVariable *var : mPipelineVariableList)
257                 {
258                     const TType &varType         = var->getType();
259                     const TBasicType samplerType = varType.getBasicType();
260 
261                     const TStructure &textureEnv = mSymbolEnv.getTextureEnv(samplerType);
262                     auto *textureEnvType         = new TType(&textureEnv, false);
263                     if (varType.isArray())
264                     {
265                         textureEnvType->makeArrays(varType.getArraySizes());
266                     }
267 
268                     fields.push_back(
269                         new TField(textureEnvType, var->name(), kNoSourceLoc, var->symbolType()));
270                 }
271             }
272             break;
273 
274             case Pipeline::Type::Image:
275             {
276                 for (const TVariable *var : mPipelineVariableList)
277                 {
278                     auto &type  = CloneType(var->getType());
279                     auto *field = new TField(&type, var->name(), kNoSourceLoc, var->symbolType());
280                     fields.push_back(field);
281                 }
282             }
283             break;
284 
285             case Pipeline::Type::UniformBuffer:
286             {
287                 for (const TVariable *var : mPipelineVariableList)
288                 {
289                     auto &type  = CloneType(var->getType());
290                     auto *field = new TField(&type, var->name(), kNoSourceLoc, var->symbolType());
291                     mSymbolEnv.markAsPointer(*field, AddressSpace::Constant);
292                     mSymbolEnv.markAsUBO(*field);
293                     mSymbolEnv.markAsPointer(*var, AddressSpace::Constant);
294                     fields.push_back(field);
295                 }
296             }
297             break;
298             default:
299             {
300                 for (const TVariable *var : mPipelineVariableList)
301                 {
302                     auto &type = CloneType(var->getType());
303                     if (mVariableInfos && IsVariableInvariant(*mVariableInfos, var->name()))
304                     {
305                         type.setInvariant(true);
306                     }
307                     auto *field = new TField(&type, var->name(), kNoSourceLoc, var->symbolType());
308                     fields.push_back(field);
309                 }
310             }
311             break;
312         }
313 
314         Name pipelineStructName = mPipeline.getStructTypeName(Pipeline::Variant::Original);
315         auto &s = *new TStructure(&mSymbolTable, pipelineStructName.rawName(), &fields,
316                                   pipelineStructName.symbolType());
317 
318         outDeclSeq.push_back(new TIntermDeclaration{&CreateStructTypeVariable(mSymbolTable, s)});
319 
320         return s;
321     }
322 };
323 
324 ////////////////////////////////////////////////////////////////////////////////
325 
CreatePipelineMainLocalVar(TSymbolTable & symbolTable,const Pipeline & pipeline,PipelineScoped<TStructure> pipelineStruct)326 PipelineScoped<TVariable> CreatePipelineMainLocalVar(TSymbolTable &symbolTable,
327                                                      const Pipeline &pipeline,
328                                                      PipelineScoped<TStructure> pipelineStruct)
329 {
330     ASSERT(pipelineStruct.isTotallyFull());
331 
332     PipelineScoped<TVariable> pipelineMainLocalVar;
333 
334     auto populateExternalMainLocalVar = [&]() {
335         ASSERT(!pipelineMainLocalVar.external);
336         pipelineMainLocalVar.external = &CreateInstanceVariable(
337             symbolTable, *pipelineStruct.external,
338             pipeline.getStructInstanceName(pipelineStruct.isUniform()
339                                                ? Pipeline::Variant::Original
340                                                : Pipeline::Variant::Modified));
341     };
342 
343     auto populateDistinctInternalMainLocalVar = [&]() {
344         ASSERT(!pipelineMainLocalVar.internal);
345         pipelineMainLocalVar.internal =
346             &CreateInstanceVariable(symbolTable, *pipelineStruct.internal,
347                                     pipeline.getStructInstanceName(Pipeline::Variant::Original));
348     };
349 
350     if (pipeline.type == Pipeline::Type::InstanceId)
351     {
352         populateDistinctInternalMainLocalVar();
353     }
354     else if (pipeline.alwaysRequiresLocalVariableDeclarationInMain())
355     {
356         populateExternalMainLocalVar();
357 
358         if (pipelineStruct.isUniform())
359         {
360             pipelineMainLocalVar.internal = pipelineMainLocalVar.external;
361         }
362         else
363         {
364             populateDistinctInternalMainLocalVar();
365         }
366     }
367     else if (!pipelineStruct.isUniform())
368     {
369         populateDistinctInternalMainLocalVar();
370     }
371 
372     return pipelineMainLocalVar;
373 }
374 
375 class PipelineFunctionEnv
376 {
377   private:
378     TCompiler &mCompiler;
379     SymbolEnv &mSymbolEnv;
380     TSymbolTable &mSymbolTable;
381     IdGen &mIdGen;
382     const Pipeline &mPipeline;
383     const std::unordered_set<const TFunction *> &mPipelineFunctions;
384     const PipelineScoped<TStructure> mPipelineStruct;
385     PipelineScoped<TVariable> &mPipelineMainLocalVar;
386     size_t mFirstParamIdxInMainFn = 0;
387 
388     std::unordered_map<const TFunction *, const TFunction *> mFuncMap;
389 
390   public:
PipelineFunctionEnv(TCompiler & compiler,SymbolEnv & symbolEnv,IdGen & idGen,const Pipeline & pipeline,const std::unordered_set<const TFunction * > & pipelineFunctions,PipelineScoped<TStructure> pipelineStruct,PipelineScoped<TVariable> & pipelineMainLocalVar)391     PipelineFunctionEnv(TCompiler &compiler,
392                         SymbolEnv &symbolEnv,
393                         IdGen &idGen,
394                         const Pipeline &pipeline,
395                         const std::unordered_set<const TFunction *> &pipelineFunctions,
396                         PipelineScoped<TStructure> pipelineStruct,
397                         PipelineScoped<TVariable> &pipelineMainLocalVar)
398         : mCompiler(compiler),
399           mSymbolEnv(symbolEnv),
400           mSymbolTable(symbolEnv.symbolTable()),
401           mIdGen(idGen),
402           mPipeline(pipeline),
403           mPipelineFunctions(pipelineFunctions),
404           mPipelineStruct(pipelineStruct),
405           mPipelineMainLocalVar(pipelineMainLocalVar)
406     {}
407 
isOriginalPipelineFunction(const TFunction & func) const408     bool isOriginalPipelineFunction(const TFunction &func) const
409     {
410         return mPipelineFunctions.find(&func) != mPipelineFunctions.end();
411     }
412 
isUpdatedPipelineFunction(const TFunction & func) const413     bool isUpdatedPipelineFunction(const TFunction &func) const
414     {
415         auto it = mFuncMap.find(&func);
416         if (it == mFuncMap.end())
417         {
418             return false;
419         }
420         return &func == it->second;
421     }
422 
getUpdatedFunction(const TFunction & func)423     const TFunction &getUpdatedFunction(const TFunction &func)
424     {
425         ASSERT(isOriginalPipelineFunction(func) || isUpdatedPipelineFunction(func));
426 
427         const TFunction *newFunc;
428 
429         auto it = mFuncMap.find(&func);
430         if (it == mFuncMap.end())
431         {
432             const bool isMain = func.isMain();
433             if (isMain)
434             {
435                 mFirstParamIdxInMainFn = func.getParamCount();
436             }
437 
438             if (isMain && mPipeline.isPipelineOut())
439             {
440                 ASSERT(func.getReturnType().getBasicType() == TBasicType::EbtVoid);
441                 newFunc = &CloneFunctionAndChangeReturnType(mSymbolTable, nullptr, func,
442                                                             *mPipelineStruct.external);
443                 if (mPipeline.type == Pipeline::Type::FragmentOut &&
444                     mCompiler.hasPixelLocalStorageUniforms() &&
445                     mCompiler.getPixelLocalStorageType() ==
446                         ShPixelLocalStorageType::FramebufferFetch)
447                 {
448                     // Add an input argument to main() that contains the current framebuffer
449                     // attachment values, for loading pixel local storage.
450                     TType *type = new TType(mPipelineStruct.externalExtra, true);
451                     TVariable *lastFragmentOut =
452                         new TVariable(&mSymbolTable, ImmutableString("lastFragmentOut"), type,
453                                       SymbolType::AngleInternal);
454                     newFunc = &CloneFunctionAndPrependParam(mSymbolTable, nullptr, *newFunc,
455                                                             *lastFragmentOut);
456                     // Initialize the main local variable with the current framebuffer contents.
457                     mPipelineMainLocalVar.externalExtra = lastFragmentOut;
458                 }
459             }
460             else if (isMain && (mPipeline.type == Pipeline::Type::InvocationVertexGlobals))
461             {
462                 ASSERT(mPipelineStruct.external->fields().size() == 1);
463                 ASSERT(mPipelineStruct.external->fields()[0]->type()->getQualifier() ==
464                        TQualifier::EvqVertexID);
465                 auto *vertexIDMetalVar =
466                     new TVariable(&mSymbolTable, ImmutableString("vertexIDMetal"),
467                                   new TType(TBasicType::EbtUInt), SymbolType::AngleInternal);
468                 newFunc                        = &func;
469                 mPipelineMainLocalVar.external = vertexIDMetalVar;
470             }
471             else if (isMain && (mPipeline.type == Pipeline::Type::InvocationFragmentGlobals))
472             {
473                 std::vector<const TVariable *> variables;
474                 for (const TField *field : mPipelineStruct.external->fields())
475                 {
476                     variables.push_back(new TVariable(&mSymbolTable, field->name(), field->type(),
477                                                       field->symbolType()));
478                 }
479                 newFunc = &CloneFunctionAndAppendParams(mSymbolTable, nullptr, func, variables);
480             }
481             else if (isMain && mPipeline.type == Pipeline::Type::Texture)
482             {
483                 std::vector<const TVariable *> variables;
484                 TranslatorMetalReflection *reflection =
485                     mtl::getTranslatorMetalReflection(&mCompiler);
486                 for (const TField *field : mPipelineStruct.external->fields())
487                 {
488                     const TStructure *textureEnv = field->type()->getStruct();
489                     ASSERT(textureEnv && textureEnv->fields().size() == 2);
490                     for (const TField *subfield : textureEnv->fields())
491                     {
492                         const Name name = mIdGen.createNewName({field->name(), subfield->name()});
493                         TType &type     = *new TType(*subfield->type());
494                         ASSERT(!type.isArray());
495                         type.makeArrays(field->type()->getArraySizes());
496                         auto *var =
497                             new TVariable(&mSymbolTable, name.rawName(), &type, name.symbolType());
498                         variables.push_back(var);
499                         reflection->addOriginalName(var->uniqueId().get(), field->name().data());
500                     }
501                 }
502                 newFunc = &CloneFunctionAndAppendParams(mSymbolTable, nullptr, func, variables);
503             }
504             else if (isMain && mPipeline.type == Pipeline::Type::InstanceId)
505             {
506                 Name instanceIdName = mPipeline.getStructInstanceName(Pipeline::Variant::Modified);
507                 auto *instanceIdVar =
508                     new TVariable(&mSymbolTable, instanceIdName.rawName(),
509                                   new TType(TBasicType::EbtUInt), instanceIdName.symbolType());
510 
511                 auto *baseInstanceVar =
512                     new TVariable(&mSymbolTable, kBaseInstanceName.rawName(),
513                                   new TType(TBasicType::EbtUInt), kBaseInstanceName.symbolType());
514 
515                 newFunc = &CloneFunctionAndPrependTwoParams(mSymbolTable, nullptr, func,
516                                                             *instanceIdVar, *baseInstanceVar);
517                 mPipelineMainLocalVar.external      = instanceIdVar;
518                 mPipelineMainLocalVar.externalExtra = baseInstanceVar;
519             }
520             else if (isMain && mPipeline.alwaysRequiresLocalVariableDeclarationInMain())
521             {
522                 ASSERT(mPipelineMainLocalVar.isTotallyFull());
523                 newFunc = &func;
524             }
525             else
526             {
527                 const TVariable *var;
528                 AddressSpace addressSpace;
529 
530                 if (isMain && !mPipelineMainLocalVar.isUniform())
531                 {
532                     var = &CreateInstanceVariable(
533                         mSymbolTable, *mPipelineStruct.external,
534                         mPipeline.getStructInstanceName(Pipeline::Variant::Modified));
535                     addressSpace = mPipeline.externalAddressSpace();
536                 }
537                 else
538                 {
539                     var = &CreateInstanceVariable(
540                         mSymbolTable, *mPipelineStruct.internal,
541                         mPipeline.getStructInstanceName(Pipeline::Variant::Original));
542                     addressSpace = mPipelineMainLocalVar.isUniform()
543                                        ? mPipeline.externalAddressSpace()
544                                        : AddressSpace::Thread;
545                 }
546 
547                 bool markAsReference = true;
548                 if (isMain)
549                 {
550                     switch (mPipeline.type)
551                     {
552                         case Pipeline::Type::VertexIn:
553                         case Pipeline::Type::FragmentIn:
554                         case Pipeline::Type::Image:
555                             markAsReference = false;
556                             break;
557 
558                         default:
559                             break;
560                     }
561                 }
562 
563                 if (markAsReference)
564                 {
565                     mSymbolEnv.markAsReference(*var, addressSpace);
566                 }
567 
568                 newFunc = &CloneFunctionAndPrependParam(mSymbolTable, nullptr, func, *var);
569             }
570 
571             mFuncMap[&func]   = newFunc;
572             mFuncMap[newFunc] = newFunc;
573         }
574         else
575         {
576             newFunc = it->second;
577         }
578 
579         return *newFunc;
580     }
581 
createUpdatedFunctionPrototype(TIntermFunctionPrototype & funcProtoNode)582     TIntermFunctionPrototype *createUpdatedFunctionPrototype(
583         TIntermFunctionPrototype &funcProtoNode)
584     {
585         const TFunction &func = *funcProtoNode.getFunction();
586         if (!isOriginalPipelineFunction(func) && !isUpdatedPipelineFunction(func))
587         {
588             return nullptr;
589         }
590         const TFunction &newFunc = getUpdatedFunction(func);
591         return new TIntermFunctionPrototype(&newFunc);
592     }
593 
getFirstParamIdxInMainFn() const594     size_t getFirstParamIdxInMainFn() const { return mFirstParamIdxInMainFn; }
595 };
596 
597 class UpdatePipelineFunctions : private TIntermRebuild
598 {
599   private:
600     const Pipeline &mPipeline;
601     const PipelineScoped<TStructure> mPipelineStruct;
602     PipelineScoped<TVariable> &mPipelineMainLocalVar;
603     SymbolEnv &mSymbolEnv;
604     PipelineFunctionEnv mEnv;
605     const TFunction *mFuncOriginalToModified;
606     const TFunction *mFuncModifiedToOriginal;
607 
608   public:
ThreadPipeline(TCompiler & compiler,TIntermBlock & root,const Pipeline & pipeline,const std::unordered_set<const TFunction * > & pipelineFunctions,PipelineScoped<TStructure> pipelineStruct,PipelineScoped<TVariable> & pipelineMainLocalVar,IdGen & idGen,SymbolEnv & symbolEnv,const TFunction * funcOriginalToModified,const TFunction * funcModifiedToOriginal)609     static bool ThreadPipeline(TCompiler &compiler,
610                                TIntermBlock &root,
611                                const Pipeline &pipeline,
612                                const std::unordered_set<const TFunction *> &pipelineFunctions,
613                                PipelineScoped<TStructure> pipelineStruct,
614                                PipelineScoped<TVariable> &pipelineMainLocalVar,
615                                IdGen &idGen,
616                                SymbolEnv &symbolEnv,
617                                const TFunction *funcOriginalToModified,
618                                const TFunction *funcModifiedToOriginal)
619     {
620         UpdatePipelineFunctions self(compiler, pipeline, pipelineFunctions, pipelineStruct,
621                                      pipelineMainLocalVar, idGen, symbolEnv, funcOriginalToModified,
622                                      funcModifiedToOriginal);
623         if (!self.rebuildRoot(root))
624         {
625             return false;
626         }
627         return true;
628     }
629 
630   private:
UpdatePipelineFunctions(TCompiler &compiler,
                        const Pipeline &pipeline,
                        const std::unordered_set<const TFunction *> &pipelineFunctions,
                        PipelineScoped<TStructure> pipelineStruct,
                        PipelineScoped<TVariable> &pipelineMainLocalVar,
                        IdGen &idGen,
                        SymbolEnv &symbolEnv,
                        const TFunction *funcOriginalToModified,
                        const TFunction *funcModifiedToOriginal)
    : TIntermRebuild(compiler, false, true),
      mPipeline(pipeline),
      mPipelineStruct(pipelineStruct),
      mPipelineMainLocalVar(pipelineMainLocalVar),
      mSymbolEnv(symbolEnv),
      // The function environment shares the same main-local-variable record.
      mEnv(compiler, symbolEnv, idGen, pipeline, pipelineFunctions, pipelineStruct,
           pipelineMainLocalVar),
      mFuncOriginalToModified(funcOriginalToModified),
      mFuncModifiedToOriginal(funcModifiedToOriginal)
{
    // Both internal and external struct types must be known before threading.
    ASSERT(mPipelineStruct.isTotallyFull());
}
657 
getInternalPipelineVariable(const TFunction & pipelineFunc)658     const TVariable &getInternalPipelineVariable(const TFunction &pipelineFunc)
659     {
660         if (pipelineFunc.isMain() && (mPipeline.alwaysRequiresLocalVariableDeclarationInMain() ||
661                                       !mPipelineMainLocalVar.isUniform()))
662         {
663             ASSERT(mPipelineMainLocalVar.internal);
664             return *mPipelineMainLocalVar.internal;
665         }
666         else
667         {
668             ASSERT(pipelineFunc.getParamCount() > 0);
669             return *pipelineFunc.getParam(0);
670         }
671     }
672 
getExternalPipelineVariable(const TFunction & mainFunc)673     const TVariable &getExternalPipelineVariable(const TFunction &mainFunc)
674     {
675         ASSERT(mainFunc.isMain());
676         if (mPipelineMainLocalVar.external)
677         {
678             return *mPipelineMainLocalVar.external;
679         }
680         else
681         {
682             ASSERT(mainFunc.getParamCount() > 0);
683             return *mainFunc.getParam(0);
684         }
685     }
686 
getExternalExtraPipelineVariable(const TFunction & mainFunc)687     const TVariable &getExternalExtraPipelineVariable(const TFunction &mainFunc)
688     {
689         ASSERT(mainFunc.isMain());
690         if (mPipelineMainLocalVar.externalExtra)
691         {
692             return *mPipelineMainLocalVar.externalExtra;
693         }
694         else
695         {
696             ASSERT(mainFunc.getParamCount() > 1);
697             return *mainFunc.getParam(1);
698         }
699     }
700 
visitAggregatePost(TIntermAggregate & callNode)701     PostResult visitAggregatePost(TIntermAggregate &callNode) override
702     {
703         if (callNode.isConstructor())
704         {
705             return callNode;
706         }
707         else
708         {
709             const TFunction &oldCalledFunc = *callNode.getFunction();
710             if (!mEnv.isOriginalPipelineFunction(oldCalledFunc))
711             {
712                 return callNode;
713             }
714             const TFunction &newCalledFunc = mEnv.getUpdatedFunction(oldCalledFunc);
715 
716             const TFunction *oldOwnerFunc = getParentFunction();
717             ASSERT(oldOwnerFunc);
718             const TFunction &newOwnerFunc = mEnv.getUpdatedFunction(*oldOwnerFunc);
719 
720             return *TIntermAggregate::CreateFunctionCall(
721                 newCalledFunc, &CloneSequenceAndPrepend(
722                                    *callNode.getSequence(),
723                                    *new TIntermSymbol(&getInternalPipelineVariable(newOwnerFunc))));
724         }
725     }
726 
visitFunctionPrototypePost(TIntermFunctionPrototype & funcProtoNode)727     PostResult visitFunctionPrototypePost(TIntermFunctionPrototype &funcProtoNode) override
728     {
729         TIntermFunctionPrototype *newFuncProtoNode =
730             mEnv.createUpdatedFunctionPrototype(funcProtoNode);
731         if (newFuncProtoNode == nullptr)
732         {
733             return funcProtoNode;
734         }
735         return *newFuncProtoNode;
736     }
737 
visitFunctionDefinitionPost(TIntermFunctionDefinition & funcDefNode)738     PostResult visitFunctionDefinitionPost(TIntermFunctionDefinition &funcDefNode) override
739     {
740         if (funcDefNode.getFunction()->isMain())
741         {
742             return visitMain(funcDefNode);
743         }
744         else
745         {
746             return visitNonMain(funcDefNode);
747         }
748     }
749 
visitNonMain(TIntermFunctionDefinition & funcDefNode)750     TIntermNode &visitNonMain(TIntermFunctionDefinition &funcDefNode)
751     {
752         TIntermFunctionPrototype &funcProtoNode = *funcDefNode.getFunctionPrototype();
753         ASSERT(!funcProtoNode.getFunction()->isMain());
754 
755         TIntermFunctionPrototype *newFuncProtoNode =
756             mEnv.createUpdatedFunctionPrototype(funcProtoNode);
757         if (newFuncProtoNode == nullptr)
758         {
759             return funcDefNode;
760         }
761 
762         const TFunction &func = *newFuncProtoNode->getFunction();
763         ASSERT(!func.isMain());
764 
765         TIntermBlock *body = funcDefNode.getBody();
766 
767         return *new TIntermFunctionDefinition(newFuncProtoNode, body);
768     }
769 
// Rebuilds main()'s definition around the prototype produced by mEnv. When main() needs
// its own pipeline struct instance, the original body is wrapped:
//   1. the internal instance is declared and populated (how depends on mPipeline.type),
//   2. the original body runs,
//   3. for output pipelines the internal instance is converted to the external one and
//      the external variable is returned.
// Returns funcDefNode unchanged when mEnv has no updated prototype for it.
visitMain(TIntermFunctionDefinition & funcDefNode)770     TIntermNode &visitMain(TIntermFunctionDefinition &funcDefNode)
771     {
772         TIntermFunctionPrototype &funcProtoNode = *funcDefNode.getFunctionPrototype();
773         ASSERT(funcProtoNode.getFunction()->isMain());
774
775         TIntermFunctionPrototype *newFuncProtoNode =
776             mEnv.createUpdatedFunctionPrototype(funcProtoNode);
777         if (newFuncProtoNode == nullptr)
778         {
779             return funcDefNode;
780         }
781
782         const TFunction &func = *newFuncProtoNode->getFunction();
783         ASSERT(func.isMain());
784
        // Input pipelines only: appends `mFuncModifiedToOriginal(external, internal)`,
        // presumably converting the modified (external) representation into the original
        // (internal) one before the body runs.
785         auto callModifiedToOriginal = [&](TIntermBlock &body) {
786             ASSERT(mPipelineMainLocalVar.internal);
787             if (!mPipeline.isPipelineOut())
788             {
789                 ASSERT(mFuncModifiedToOriginal);
790                 auto *m = new TIntermSymbol(&getExternalPipelineVariable(func));
791                 auto *o = new TIntermSymbol(mPipelineMainLocalVar.internal);
792                 body.appendStatement(TIntermAggregate::CreateFunctionCall(
793                     *mFuncModifiedToOriginal, new TIntermSequence{m, o}));
794             }
795         };
796
        // Output pipelines only: appends `mFuncOriginalToModified(internal, external)`,
        // presumably converting the original (internal) representation back into the
        // modified (external) one after the body runs.
797         auto callOriginalToModified = [&](TIntermBlock &body) {
798             ASSERT(mPipelineMainLocalVar.internal);
799             if (mPipeline.isPipelineOut())
800             {
801                 ASSERT(mFuncOriginalToModified);
802                 auto *o = new TIntermSymbol(mPipelineMainLocalVar.internal);
803                 auto *m = new TIntermSymbol(&getExternalPipelineVariable(func));
804                 body.appendStatement(TIntermAggregate::CreateFunctionCall(
805                     *mFuncOriginalToModified, new TIntermSequence{o, m}));
806             }
807         };
808
809         TIntermBlock *body = funcDefNode.getBody();
810
811         if (mPipeline.alwaysRequiresLocalVariableDeclarationInMain())
812         {
813             ASSERT(mPipelineMainLocalVar.isTotallyFull());
814
            // The internal instance is always the first statement of the new body.
815             auto *newBody = new TIntermBlock();
816             newBody->appendStatement(new TIntermDeclaration{mPipelineMainLocalVar.internal});
817
818             if (mPipeline.type == Pipeline::Type::InvocationVertexGlobals)
819             {
                // The external struct has exactly one field (the vertex ID); assign
                // internal.<field> = int(<external vertex-ID variable>).
820                 ASSERT(mPipelineStruct.external->fields().size() == 1);
821                 ASSERT(mPipelineStruct.external->fields()[0]->type()->getQualifier() ==
822                        TQualifier::EvqVertexID);
823                 const TField *field = mPipelineStruct.external->fields()[0];
824                 auto *var =
825                     new TVariable(&mSymbolTable, field->name(), field->type(), field->symbolType());
826                 auto &accessNode   = AccessField(*mPipelineMainLocalVar.internal, Name(*var));
827                 auto vertexIDMetal = new TIntermSymbol(&getExternalPipelineVariable(func));
828                 auto *assignNode   = new TIntermBinary(
829                     TOperator::EOpAssign, &accessNode,
830                     &AsType(mSymbolEnv, *new TType(TBasicType::EbtInt), *vertexIDMetal));
831                 newBody->appendStatement(assignNode);
832             }
833             else if (mPipeline.type == Pipeline::Type::InvocationFragmentGlobals)
834             {
835                 // Populate struct instance with references to global pipeline variables.
                // Each field is assigned from a newly created variable symbol with the same
                // name, type and symbol type as the field.
836                 for (const TField *field : mPipelineStruct.external->fields())
837                 {
838                     auto *var        = new TVariable(&mSymbolTable, field->name(), field->type(),
839                                                      field->symbolType());
840                     auto *symbol     = new TIntermSymbol(var);
841                     auto &accessNode = AccessField(*mPipelineMainLocalVar.internal, Name(*var));
842                     auto *assignNode = new TIntermBinary(TOperator::EOpAssign, &accessNode, symbol);
843                     newBody->appendStatement(assignNode);
844                 }
845             }
846             else if (mPipeline.type == Pipeline::Type::FragmentOut &&
847                      mCompiler.hasPixelLocalStorageUniforms() &&
848                      mCompiler.getPixelLocalStorageType() ==
849                          ShPixelLocalStorageType::FramebufferFetch)
850             {
                // Pixel local storage via framebuffer fetch: seed every field of the internal
                // instance from the matching field of externalExtra (the last fragment output).
                // NOTE(review): unlike the generic non-uniform branch below, this branch does not
                // declare mPipelineMainLocalVar.external, yet the trailing
                // callOriginalToModified/return reference it. Presumably the FragmentOut pipeline
                // is uniform here; confirm.
851                 ASSERT(mPipelineMainLocalVar.externalExtra);
852                 auto &lastFragmentOut = *mPipelineMainLocalVar.externalExtra;
853                 for (const TField *field : lastFragmentOut.getType().getStruct()->fields())
854                 {
855                     auto &accessNode = AccessField(*mPipelineMainLocalVar.internal, Name(*field));
856                     auto &sourceNode = AccessField(lastFragmentOut, Name(*field));
857                     auto *assignNode =
858                         new TIntermBinary(TOperator::EOpAssign, &accessNode, &sourceNode);
859                     newBody->appendStatement(assignNode);
860                 }
861             }
862             else if (mPipeline.type == Pipeline::Type::Texture)
863             {
                // main() receives one (texture, sampler) parameter pair per struct field,
                // starting at mEnv.getFirstParamIdxInMainFn(). For every field (and every
                // element of an array field), assign
                //   env[i].texture = addressof(textureParam[i])
                //   env[i].sampler = addressof(samplerParam[i])
                // (without indexing for non-array fields).
864                 const TFieldList &fields = mPipelineStruct.external->fields();
865
866                 ASSERT(func.getParamCount() >= mEnv.getFirstParamIdxInMainFn() + 2 * fields.size());
867                 size_t paramIndex = mEnv.getFirstParamIdxInMainFn();
868
869                 for (const TField *field : fields)
870                 {
871                     const TVariable &textureParam = *func.getParam(paramIndex++);
872                     const TVariable &samplerParam = *func.getParam(paramIndex++);
873
                    // `index` is nullptr for a non-array field, otherwise the element index.
874                     auto go = [&](TIntermTyped &env, const int *index) {
875                         TIntermTyped &textureField =
876                             AccessField(AccessIndex(*env.deepCopy(), index),
877                                         Name("texture", SymbolType::BuiltIn));
878                         TIntermTyped &samplerField =
879                             AccessField(AccessIndex(*env.deepCopy(), index),
880                                         Name("sampler", SymbolType::BuiltIn));
881
882                         auto mkAssign = [&](TIntermTyped &field, const TVariable &param) {
883                             return new TIntermBinary(TOperator::EOpAssign, &field,
884                                                      &mSymbolEnv.callFunctionOverload(
885                                                          Name("addressof"), field.getType(),
886                                                          *new TIntermSequence{&AccessIndex(
887                                                              *new TIntermSymbol(&param), index)}));
888                         };
889
890                         newBody->appendStatement(mkAssign(textureField, textureParam));
891                         newBody->appendStatement(mkAssign(samplerField, samplerParam));
892                     };
893
894                     TIntermTyped &env = AccessField(*mPipelineMainLocalVar.internal, Name(*field));
895                     const TType &envType = env.getType();
896
897                     if (envType.isArray())
898                     {
899                         ASSERT(!envType.isArrayOfArrays());
900                         const auto n = static_cast<int>(envType.getArraySizeProduct());
901                         for (int i = 0; i < n; ++i)
902                         {
903                             go(env, &i);
904                         }
905                     }
906                     else
907                     {
908                         go(env, nullptr);
909                     }
910                 }
911             }
912             else if (mPipeline.type == Pipeline::Type::InstanceId)
913             {
                // internal field 0 = int(<instance id> - <base instance>), where the two
                // operands are main()'s external and externalExtra pipeline variables.
914                 auto varInstanceId   = new TIntermSymbol(&getExternalPipelineVariable(func));
915                 auto varBaseInstance = new TIntermSymbol(&getExternalExtraPipelineVariable(func));
916
917                 newBody->appendStatement(new TIntermBinary(
918                     TOperator::EOpAssign,
919                     &AccessFieldByIndex(*new TIntermSymbol(&getInternalPipelineVariable(func)), 0),
920                     &AsType(
921                         mSymbolEnv, *new TType(TBasicType::EbtInt),
922                         *new TIntermBinary(TOperator::EOpSub, varInstanceId, varBaseInstance))));
923             }
924             else if (!mPipelineMainLocalVar.isUniform())
925             {
                // Generic case: declare a separate external instance and (for input
                // pipelines) convert it into the internal instance.
926                 newBody->appendStatement(new TIntermDeclaration{mPipelineMainLocalVar.external});
927                 callModifiedToOriginal(*newBody);
928             }
929
            // The original body runs after the prologue.
930             newBody->appendStatement(body);
931
932             if (!mPipelineMainLocalVar.isUniform())
933             {
934                 callOriginalToModified(*newBody);
935             }
936
            // Output pipelines hand the external instance back as main()'s return value.
937             if (mPipeline.isPipelineOut())
938             {
939                 newBody->appendStatement(new TIntermBranch(
940                     TOperator::EOpReturn, new TIntermSymbol(mPipelineMainLocalVar.external)));
941             }
942
943             body = newBody;
944         }
945         else if (!mPipelineMainLocalVar.isUniform())
946         {
            // No dedicated external variable here (asserted below). Declare the internal
            // instance, convert on entry (inputs) and on exit (outputs) around the body.
947             ASSERT(!mPipelineMainLocalVar.external);
948             ASSERT(mPipelineMainLocalVar.internal);
949
950             auto *newBody = new TIntermBlock();
951             newBody->appendStatement(new TIntermDeclaration{mPipelineMainLocalVar.internal});
952             callModifiedToOriginal(*newBody);
953             newBody->appendStatement(body);
954             callOriginalToModified(*newBody);
955             body = newBody;
956         }
957
        // Uniform pipelines that need no local declaration keep the original body.
958         return *new TIntermFunctionDefinition(newFuncProtoNode, body);
959     }
960 };
961 
962 ////////////////////////////////////////////////////////////////////////////////
963 
UpdatePipelineSymbols(Pipeline::Type pipelineType,TCompiler & compiler,TIntermBlock & root,SymbolEnv & symbolEnv,const VariableSet & pipelineVariables,PipelineScoped<TVariable> pipelineMainLocalVar)964 bool UpdatePipelineSymbols(Pipeline::Type pipelineType,
965                            TCompiler &compiler,
966                            TIntermBlock &root,
967                            SymbolEnv &symbolEnv,
968                            const VariableSet &pipelineVariables,
969                            PipelineScoped<TVariable> pipelineMainLocalVar)
970 {
971     auto map = [&](const TFunction *owner, TIntermSymbol &symbol) -> TIntermNode & {
972         if (!owner)
973             return symbol;
974         const TVariable &var = symbol.variable();
975         if (pipelineVariables.find(&var) == pipelineVariables.end())
976         {
977             return symbol;
978         }
979         const TVariable *structInstanceVar;
980         if (owner->isMain() && pipelineType != Pipeline::Type::FragmentIn)
981         {
982             ASSERT(pipelineMainLocalVar.internal);
983             structInstanceVar = pipelineMainLocalVar.internal;
984         }
985         else
986         {
987             ASSERT(owner->getParamCount() > 0);
988             structInstanceVar = owner->getParam(0);
989         }
990         ASSERT(structInstanceVar);
991         return AccessField(*structInstanceVar, Name(var));
992     };
993     return MapSymbols(compiler, root, map);
994 }
995 
996 ////////////////////////////////////////////////////////////////////////////////
997 
RewritePipeline(TCompiler & compiler,TIntermBlock & root,IdGen & idGen,const Pipeline & pipeline,SymbolEnv & symbolEnv,const std::vector<sh::ShaderVariable> * variableInfo,PipelineScoped<TStructure> & outStruct)998 bool RewritePipeline(TCompiler &compiler,
999                      TIntermBlock &root,
1000                      IdGen &idGen,
1001                      const Pipeline &pipeline,
1002                      SymbolEnv &symbolEnv,
1003                      const std::vector<sh::ShaderVariable> *variableInfo,
1004                      PipelineScoped<TStructure> &outStruct)
1005 {
1006     ASSERT(outStruct.isTotallyEmpty());
1007 
1008     TSymbolTable &symbolTable = compiler.getSymbolTable();
1009 
1010     PipelineStructInfo psi;
1011     if (!GeneratePipelineStruct::Exec(psi, compiler, root, idGen, pipeline, symbolEnv,
1012                                       variableInfo))
1013     {
1014         return false;
1015     }
1016 
1017     if (psi.isEmpty())
1018     {
1019         return true;
1020     }
1021 
1022     const auto pipelineFunctions = DiscoverDependentFunctions(root, [&](const TVariable &var) {
1023         return psi.pipelineVariables.find(&var) != psi.pipelineVariables.end();
1024     });
1025 
1026     auto pipelineMainLocalVar =
1027         CreatePipelineMainLocalVar(symbolTable, pipeline, psi.pipelineStruct);
1028 
1029     if (!UpdatePipelineFunctions::ThreadPipeline(
1030             compiler, root, pipeline, pipelineFunctions, psi.pipelineStruct, pipelineMainLocalVar,
1031             idGen, symbolEnv, psi.funcOriginalToModified, psi.funcModifiedToOriginal))
1032     {
1033         return false;
1034     }
1035 
1036     if (!pipeline.globalInstanceVar)
1037     {
1038         if (!UpdatePipelineSymbols(pipeline.type, compiler, root, symbolEnv, psi.pipelineVariables,
1039                                    pipelineMainLocalVar))
1040         {
1041             return false;
1042         }
1043     }
1044 
1045     if (!PruneNoOps(&compiler, &root, &compiler.getSymbolTable()))
1046     {
1047         return false;
1048     }
1049 
1050     outStruct = psi.pipelineStruct;
1051     return true;
1052 }
1053 
1054 }  // anonymous namespace
1055 
RewritePipelines(TCompiler & compiler,TIntermBlock & root,const std::vector<sh::ShaderVariable> & inputVaryings,const std::vector<sh::ShaderVariable> & outputVaryings,IdGen & idGen,DriverUniform & angleUniformsGlobalInstanceVar,SymbolEnv & symbolEnv,PipelineStructs & outStructs)1056 bool sh::RewritePipelines(TCompiler &compiler,
1057                           TIntermBlock &root,
1058                           const std::vector<sh::ShaderVariable> &inputVaryings,
1059                           const std::vector<sh::ShaderVariable> &outputVaryings,
1060                           IdGen &idGen,
1061                           DriverUniform &angleUniformsGlobalInstanceVar,
1062                           SymbolEnv &symbolEnv,
1063                           PipelineStructs &outStructs)
1064 {
1065     struct Info
1066     {
1067         Pipeline::Type pipelineType;
1068         PipelineScoped<TStructure> &outStruct;
1069         const TVariable *globalInstanceVar;
1070         const std::vector<sh::ShaderVariable> *variableInfo;
1071     };
1072 
1073     Info infos[] = {
1074         {Pipeline::Type::InstanceId, outStructs.instanceId, nullptr, nullptr},
1075         {Pipeline::Type::Texture, outStructs.image, nullptr, nullptr},
1076         {Pipeline::Type::Image, outStructs.texture, nullptr, nullptr},
1077         {Pipeline::Type::NonConstantGlobals, outStructs.nonConstantGlobals, nullptr, nullptr},
1078         {Pipeline::Type::AngleUniforms, outStructs.angleUniforms,
1079          angleUniformsGlobalInstanceVar.getDriverUniformsVariable(), nullptr},
1080         {Pipeline::Type::UserUniforms, outStructs.userUniforms, nullptr, nullptr},
1081         {Pipeline::Type::VertexIn, outStructs.vertexIn, nullptr, &inputVaryings},
1082         {Pipeline::Type::VertexOut, outStructs.vertexOut, nullptr, &outputVaryings},
1083         {Pipeline::Type::FragmentIn, outStructs.fragmentIn, nullptr, &inputVaryings},
1084         {Pipeline::Type::FragmentOut, outStructs.fragmentOut, nullptr, &outputVaryings},
1085         {Pipeline::Type::InvocationVertexGlobals, outStructs.invocationVertexGlobals, nullptr,
1086          nullptr},
1087         {Pipeline::Type::InvocationFragmentGlobals, outStructs.invocationFragmentGlobals, nullptr,
1088          &inputVaryings},
1089         {Pipeline::Type::UniformBuffer, outStructs.uniformBuffers, nullptr, nullptr},
1090     };
1091 
1092     for (Info &info : infos)
1093     {
1094         if ((compiler.getShaderType() != GL_VERTEX_SHADER &&
1095              (info.pipelineType == Pipeline::Type::VertexIn ||
1096               info.pipelineType == Pipeline::Type::VertexOut ||
1097               info.pipelineType == Pipeline::Type::InvocationVertexGlobals)) ||
1098             (compiler.getShaderType() != GL_FRAGMENT_SHADER &&
1099              (info.pipelineType == Pipeline::Type::FragmentIn ||
1100               info.pipelineType == Pipeline::Type::FragmentOut ||
1101               info.pipelineType == Pipeline::Type::InvocationFragmentGlobals)))
1102             continue;
1103 
1104         Pipeline pipeline{info.pipelineType, info.globalInstanceVar};
1105         if (!RewritePipeline(compiler, root, idGen, pipeline, symbolEnv, info.variableInfo,
1106                              info.outStruct))
1107         {
1108             return false;
1109         }
1110     }
1111 
1112     return true;
1113 }
1114